Skip to content

Commit ae120dd

Browse files
dmakushinDmitrii Makushin
authored and
Dmitrii Makushin
committed
Add RunInTx method for DB
1 parent 992acfb commit ae120dd

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

go.sum

-2
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,6 @@ github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w=
109109
github.com/spf13/cast v1.5.0/go.mod h1:SpXXQ5YoyJw6s3/6cMTQuxvgRl3PCJiyaX9p6b155UU=
110110
github.com/stephenafamo/fakedb v0.0.0-20221230081958-0b86f816ed97 h1:XItoZNmhOih06TC02jK7l3wlpZ0XT/sPQYutDcGOQjg=
111111
github.com/stephenafamo/fakedb v0.0.0-20221230081958-0b86f816ed97/go.mod h1:bM3Vmw1IakoaXocHmMIGgJFYob0vuK+CFWiJHQvz0jQ=
112-
github.com/stephenafamo/scan v0.6.0 h1:N0joyP/wriC9VvP6w9SDxHIuQGatW4c2YW7Z5L4m45s=
113-
github.com/stephenafamo/scan v0.6.0/go.mod h1:FhIUJ8pLNyex36xGFiazDJJ5Xry0UkAi+RkWRrEcRMg=
114112
github.com/stephenafamo/scan v0.6.1 h1:nXokGCQwYazMuyvdNAoK0T8Z76FWcpMvDdtengpz6PU=
115113
github.com/stephenafamo/scan v0.6.1/go.mod h1:FhIUJ8pLNyex36xGFiazDJJ5Xry0UkAi+RkWRrEcRMg=
116114
github.com/stephenafamo/sqlparser v0.0.0-20241111104950-b04fa8a26c9c h1:JFga++XBnZG2xlnvQyHJkeBWZ9G9mGdtgvLeSRbp/BA=

stdlib.go

+28
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"database/sql"
66
"database/sql/driver"
7+
"errors"
8+
"fmt"
79

810
"github.com/stephenafamo/scan"
911
"github.com/stephenafamo/scan/stdscan"
@@ -96,6 +98,32 @@ func (d DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (Tx, error) {
9698
return NewTx(tx), nil
9799
}
98100

101+
// RunInTx runs the provided function in a transaction.
102+
// If the function returns an error, the transaction is rolled back.
103+
// Otherwise, the transaction is committed.
104+
func (d DB) RunInTx(ctx context.Context, txOptions *sql.TxOptions, fn func(context.Context, Tx) error) error {
105+
tx, err := d.BeginTx(ctx, txOptions)
106+
if err != nil {
107+
return fmt.Errorf("begin transaction: %w", err)
108+
}
109+
110+
if err := fn(ctx, tx); err != nil {
111+
err = fmt.Errorf("call method in transaction: %w", err)
112+
113+
if rollbackErr := tx.Rollback(); rollbackErr != nil {
114+
return errors.Join(err, rollbackErr)
115+
}
116+
117+
return err
118+
}
119+
120+
if err := tx.Commit(); err != nil {
121+
return fmt.Errorf("commit transaction: %w", err)
122+
}
123+
124+
return nil
125+
}
126+
99127
// NewTx wraps an [*sql.Tx] and returns a type that implements [Queryer] but still
100128
// retains the expected methods used by *sql.Tx
101129
// This is useful when an existing *sql.Tx is used in other places in the codebase

0 commit comments

Comments
 (0)