Skip to content

Commit 7f75b12

Browse files
authored
Generate unique savepoint names for nested transactions (#7174)
* Generate unique savepoint names * Add a test for deeply nested wrapped transactions
1 parent 0daaf17 commit 7f75b12

File tree

3 files changed

+74
-4
lines changed

3 files changed

+74
-4
lines changed

finisher_api.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"database/sql"
55
"errors"
66
"fmt"
7+
"hash/maphash"
78
"reflect"
89
"strings"
910

@@ -623,14 +624,15 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er
623624
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
624625
// nested transaction
625626
if !db.DisableNestedTransaction {
626-
err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error
627+
spID := new(maphash.Hash).Sum64()
628+
err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error
627629
if err != nil {
628630
return
629631
}
630632
defer func() {
631633
// Make sure to rollback when panic, Block error or Commit error
632634
if panicked || err != nil {
633-
db.RollbackTo(fmt.Sprintf("sp%p", fc))
635+
db.RollbackTo(fmt.Sprintf("sp%d", spID))
634636
}
635637
}()
636638
}

tests/go.mod

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ require (
2929
github.com/microsoft/go-mssqldb v1.7.2 // indirect
3030
github.com/pmezard/go-difflib v1.0.0 // indirect
3131
github.com/rogpeppe/go-internal v1.12.0 // indirect
32-
golang.org/x/crypto v0.24.0 // indirect
33-
golang.org/x/text v0.16.0 // indirect
32+
golang.org/x/crypto v0.26.0 // indirect
33+
golang.org/x/text v0.17.0 // indirect
3434
gopkg.in/yaml.v3 v3.0.1 // indirect
3535
)
3636

tests/transaction_test.go

+68
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,74 @@ func TestNestedTransactionWithBlock(t *testing.T) {
297297
}
298298
}
299299

300+
func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) {
301+
transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error {
302+
return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
303+
return callback(ctx, tx)
304+
})
305+
}
306+
var (
307+
user = *GetUser("transaction-nested", Config{})
308+
user1 = *GetUser("transaction-nested-1", Config{})
309+
user2 = *GetUser("transaction-nested-2", Config{})
310+
)
311+
312+
if err := transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error {
313+
tx.Create(&user)
314+
315+
if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
316+
t.Fatalf("Should find saved record")
317+
}
318+
319+
if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error {
320+
tx1.Create(&user1)
321+
322+
if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil {
323+
t.Fatalf("Should find saved record")
324+
}
325+
326+
if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error {
327+
tx2.Create(&user2)
328+
329+
if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil {
330+
t.Fatalf("Should find saved record")
331+
}
332+
333+
return errors.New("inner rollback")
334+
}); err == nil {
335+
t.Fatalf("nested transaction has no error")
336+
}
337+
338+
return errors.New("rollback")
339+
}); err == nil {
340+
t.Fatalf("nested transaction should returns error")
341+
}
342+
343+
if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil {
344+
t.Fatalf("Should not find rollbacked record")
345+
}
346+
347+
if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
348+
t.Fatalf("Should find saved record")
349+
}
350+
return nil
351+
}); err != nil {
352+
t.Fatalf("no error should return, but got %v", err)
353+
}
354+
355+
if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
356+
t.Fatalf("Should find saved record")
357+
}
358+
359+
if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil {
360+
t.Fatalf("Should not find rollbacked parent record")
361+
}
362+
363+
if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
364+
t.Fatalf("Should not find rollbacked nested record")
365+
}
366+
}
367+
300368
func TestDisabledNestedTransaction(t *testing.T) {
301369
var (
302370
user = *GetUser("transaction-nested", Config{})

0 commit comments

Comments
 (0)