Skip to content

Commit e1604ae

Browse files
committed
feat: add support for scanning slices of sql.Scanner structs
1 parent d9f273f commit e1604ae

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

internal/dbtest/db_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"path/filepath"
1212
"reflect"
1313
"runtime"
14+
"strconv"
1415
"strings"
1516
"testing"
1617
"time"
@@ -265,6 +266,7 @@ func TestDB(t *testing.T) {
265266
{testSelectNestedStructValue},
266267
{testSelectNestedStructPtr},
267268
{testSelectStructSlice},
269+
{testSelectScannerSlice},
268270
{testSelectSingleSlice},
269271
{testSelectMultiSlice},
270272
{testSelectJSONMap},
@@ -521,6 +523,63 @@ func testSelectStructSlice(t *testing.T, db *bun.DB) {
521523
}
522524
}
523525

526+
type CustomNum struct {
527+
Num int
528+
}
529+
530+
func (n *CustomNum) Scan(src any) error {
531+
switch val := src.(type) {
532+
case int32:
533+
*n = CustomNum{int(val)}
534+
case uint32:
535+
*n = CustomNum{int(val)}
536+
case int64:
537+
*n = CustomNum{int(val)}
538+
case uint64:
539+
*n = CustomNum{int(val)}
540+
case []byte:
541+
num, err := strconv.ParseInt(string(val), 10, 64)
542+
if err != nil {
543+
return err
544+
}
545+
*n = CustomNum{int(num)}
546+
case string:
547+
num, err := strconv.ParseInt(val, 10, 64)
548+
if err != nil {
549+
return err
550+
}
551+
*n = CustomNum{int(num)}
552+
default:
553+
return fmt.Errorf("unsupported type: %T", val)
554+
}
555+
return nil
556+
}
557+
558+
var _ sql.Scanner = (*CustomNum)(nil)
559+
560+
func testSelectScannerSlice(t *testing.T, db *bun.DB) {
561+
if !db.Dialect().Features().Has(feature.CTE) {
562+
t.Skip()
563+
}
564+
565+
values := db.NewValues(&[]map[string]interface{}{
566+
{"column1": 1},
567+
{"column1": 2},
568+
{"column1": 3},
569+
})
570+
571+
var ns []CustomNum
572+
err := db.NewSelect().
573+
With("t", values).
574+
TableExpr("t").
575+
Scan(ctx, &ns)
576+
require.NoError(t, err)
577+
require.Len(t, ns, 3)
578+
for i, n := range ns {
579+
require.Equal(t, i+1, n.Num)
580+
}
581+
}
582+
524583
func testSelectSingleSlice(t *testing.T, db *bun.DB) {
525584
if !db.Dialect().Features().Has(feature.CTE) {
526585
t.Skip()

model.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ import (
1414
var errNilModel = errors.New("bun: Model(nil)")
1515

1616
var (
17-
timeType = reflect.TypeFor[time.Time]()
18-
bytesType = reflect.TypeFor[[]byte]()
17+
timeType = reflect.TypeFor[time.Time]()
18+
bytesType = reflect.TypeFor[[]byte]()
19+
scannerType = reflect.TypeFor[sql.Scanner]()
1920
)
2021

2122
type Model = schema.Model
@@ -125,7 +126,7 @@ func _newModel(db *DB, dest any, scan bool) (Model, error) {
125126
case reflect.Slice:
126127
switch elemType := sliceElemType(v); elemType.Kind() {
127128
case reflect.Struct:
128-
if elemType != timeType {
129+
if elemType != timeType && !reflect.PointerTo(elemType).Implements(scannerType) {
129130
return newSliceTableModel(db, dest, v, elemType), nil
130131
}
131132
case reflect.Map:

0 commit comments

Comments
 (0)