@@ -15,6 +15,8 @@ import (
1515 "time"
1616
1717 "github.com/uptrace/bun/internal"
18+ semconv "go.opentelemetry.io/otel/semconv/v1.21.0"
19+ "go.opentelemetry.io/otel/trace"
1820)
1921
2022func init () {
@@ -68,38 +70,38 @@ func (d Driver) Open(name string) (driver.Conn, error) {
6870//------------------------------------------------------------------------------
6971
7072type Connector struct {
71- cfg * Config
73+ conf * Config
7274}
7375
7476func NewConnector (opts ... Option ) * Connector {
75- c := & Connector {cfg : newDefaultConfig ()}
77+ c := & Connector {conf : newDefaultConfig ()}
7678 for _ , opt := range opts {
77- opt (c .cfg )
79+ opt (c .conf )
7880 }
7981 return c
8082}
8183
8284var _ driver.Connector = (* Connector )(nil )
8385
8486func (c * Connector ) Connect (ctx context.Context ) (driver.Conn , error ) {
85- if err := c .cfg .verify (); err != nil {
87+ if err := c .conf .verify (); err != nil {
8688 return nil , err
8789 }
88- return newConn (ctx , c .cfg )
90+ return newConn (ctx , c .conf )
8991}
9092
9193func (c * Connector ) Driver () driver.Driver {
9294 return Driver {connector : c }
9395}
9496
9597func (c * Connector ) Config () * Config {
96- return c .cfg
98+ return c .conf
9799}
98100
99101//------------------------------------------------------------------------------
100102
101103type Conn struct {
102- cfg * Config
104+ conf * Config
103105
104106 netConn net.Conn
105107 rd * reader
@@ -112,20 +114,20 @@ type Conn struct {
112114 closed int32
113115}
114116
115- func newConn (ctx context.Context , cfg * Config ) (* Conn , error ) {
116- netConn , err := cfg .Dialer (ctx , cfg .Network , cfg .Addr )
117+ func newConn (ctx context.Context , conf * Config ) (* Conn , error ) {
118+ netConn , err := conf .Dialer (ctx , conf .Network , conf .Addr )
117119 if err != nil {
118120 return nil , err
119121 }
120122
121123 cn := & Conn {
122- cfg : cfg ,
124+ conf : conf ,
123125 netConn : netConn ,
124126 rd : newReader (netConn ),
125127 }
126128
127- if cfg .TLSConfig != nil {
128- if err := enableSSL (ctx , cn , cfg .TLSConfig ); err != nil {
129+ if conf .TLSConfig != nil {
130+ if err := enableSSL (ctx , cn , conf .TLSConfig ); err != nil {
129131 return nil , err
130132 }
131133 }
@@ -134,7 +136,7 @@ func newConn(ctx context.Context, cfg *Config) (*Conn, error) {
134136 return nil , err
135137 }
136138
137- for k , v := range cfg .ConnParams {
139+ for k , v := range conf .ConnParams {
138140 if v != nil {
139141 _ , err = cn .ExecContext (ctx , fmt .Sprintf ("SET %s TO $1" , k ), []driver.NamedValue {
140142 {Value : v },
@@ -150,6 +152,17 @@ func newConn(ctx context.Context, cfg *Config) (*Conn, error) {
150152 return cn , nil
151153}
152154
155+ func (cn * Conn ) Close () error {
156+ if ! atomic .CompareAndSwapInt32 (& cn .closed , 0 , 1 ) {
157+ return nil
158+ }
159+ return cn .netConn .Close ()
160+ }
161+
162+ func (cn * Conn ) isClosed () bool {
163+ return atomic .LoadInt32 (& cn .closed ) == 1
164+ }
165+
153166func (cn * Conn ) reader (ctx context.Context , timeout time.Duration ) * reader {
154167 cn .setReadDeadline (ctx , timeout )
155168 return cn .rd
@@ -174,11 +187,16 @@ func (cn *Conn) write(ctx context.Context, wb *writeBuffer) error {
174187var _ driver.Conn = (* Conn )(nil )
175188
176189func (cn * Conn ) Prepare (query string ) (driver.Stmt , error ) {
190+ return cn .PrepareContext (context .Background (), query )
191+ }
192+
193+ var _ driver.ConnPrepareContext = (* Conn )(nil )
194+
195+ func (cn * Conn ) PrepareContext (ctx context.Context , query string ) (driver.Stmt , error ) {
177196 if cn .isClosed () {
178197 return nil , driver .ErrBadConn
179198 }
180-
181- ctx := context .TODO ()
199+ cn .trace (ctx )
182200
183201 name := fmt .Sprintf ("pgdriver-%d" , cn .stmtCount )
184202 cn .stmtCount ++
@@ -195,32 +213,29 @@ func (cn *Conn) Prepare(query string) (driver.Stmt, error) {
195213 return newStmt (cn , name , rowDesc ), nil
196214}
197215
198- func (cn * Conn ) Close () error {
199- if ! atomic .CompareAndSwapInt32 (& cn .closed , 0 , 1 ) {
200- return nil
201- }
202- return cn .netConn .Close ()
203- }
204-
205- func (cn * Conn ) isClosed () bool {
206- return atomic .LoadInt32 (& cn .closed ) == 1
207- }
208-
209216func (cn * Conn ) Begin () (driver.Tx , error ) {
210217 return cn .BeginTx (context .Background (), driver.TxOptions {})
211218}
212219
213220var _ driver.ConnBeginTx = (* Conn )(nil )
214221
215222func (cn * Conn ) BeginTx (ctx context.Context , opts driver.TxOptions ) (driver.Tx , error ) {
223+ if cn .isClosed () {
224+ return nil , driver .ErrBadConn
225+ }
226+ cn .trace (ctx )
227+
216228 // No need to check if the conn is closed. ExecContext below handles that.
217229 isolation := sql .IsolationLevel (opts .Isolation )
218230
219231 var command string
220232 switch isolation {
221233 case sql .LevelDefault :
222234 command = "BEGIN"
223- case sql .LevelReadUncommitted , sql .LevelReadCommitted , sql .LevelRepeatableRead , sql .LevelSerializable :
235+ case sql .LevelReadUncommitted ,
236+ sql .LevelReadCommitted ,
237+ sql .LevelRepeatableRead ,
238+ sql .LevelSerializable :
224239 command = fmt .Sprintf ("BEGIN; SET TRANSACTION ISOLATION LEVEL %s" , isolation .String ())
225240 default :
226241 return nil , fmt .Errorf ("pgdriver: unsupported transaction isolation: %s" , isolation .String ())
@@ -244,6 +259,8 @@ func (cn *Conn) ExecContext(
244259 if cn .isClosed () {
245260 return nil , driver .ErrBadConn
246261 }
262+ cn .trace (ctx )
263+
247264 res , err := cn .exec (ctx , query , args )
248265 if err != nil {
249266 return nil , cn .checkBadConn (err )
@@ -272,6 +289,8 @@ func (cn *Conn) QueryContext(
272289 if cn .isClosed () {
273290 return nil , driver .ErrBadConn
274291 }
292+ cn .trace (ctx )
293+
275294 rows , err := cn .query (ctx , query , args )
276295 if err != nil {
277296 return nil , cn .checkBadConn (err )
@@ -301,14 +320,14 @@ func (cn *Conn) Ping(ctx context.Context) error {
301320
302321func (cn * Conn ) setReadDeadline (ctx context.Context , timeout time.Duration ) {
303322 if timeout == - 1 {
304- timeout = cn .cfg .ReadTimeout
323+ timeout = cn .conf .ReadTimeout
305324 }
306325 _ = cn .netConn .SetReadDeadline (cn .deadline (ctx , timeout ))
307326}
308327
309328func (cn * Conn ) setWriteDeadline (ctx context.Context , timeout time.Duration ) {
310329 if timeout == - 1 {
311- timeout = cn .cfg .WriteTimeout
330+ timeout = cn .conf .WriteTimeout
312331 }
313332 _ = cn .netConn .SetWriteDeadline (cn .deadline (ctx , timeout ))
314333}
@@ -343,8 +362,8 @@ func (cn *Conn) ResetSession(ctx context.Context) error {
343362 if cn .isClosed () {
344363 return driver .ErrBadConn
345364 }
346- if cn .cfg .ResetSessionFunc != nil {
347- return cn .cfg .ResetSessionFunc (ctx , cn )
365+ if cn .conf .ResetSessionFunc != nil {
366+ return cn .conf .ResetSessionFunc (ctx , cn )
348367 }
349368 return nil
350369}
@@ -360,6 +379,16 @@ func (cn *Conn) checkBadConn(err error) error {
360379
361380func (cn * Conn ) Conn () net.Conn { return cn .netConn }
362381
382+ func (cn * Conn ) trace (ctx context.Context ) {
383+ if span := trace .SpanFromContext (ctx ); span .IsRecording () {
384+ span .SetAttributes (
385+ semconv .DBUserKey .String (cn .conf .User ),
386+ semconv .DBNameKey .String (cn .conf .Database ),
387+ semconv .ServerAddressKey .String (cn .conf .Addr ),
388+ )
389+ }
390+ }
391+
363392//------------------------------------------------------------------------------
364393
365394type rows struct {
0 commit comments