Skip to content

Commit 57e2ecd

Browse files
committed
feat: add support for scanning slices of sql.Scanner structs
1 parent 4123360 commit 57e2ecd

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

internal/dbtest/db_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ func TestDB(t *testing.T) {
266266
{testSelectNestedStructValue},
267267
{testSelectNestedStructPtr},
268268
{testSelectStructSlice},
269+
{testSelectScannerSlice},
269270
{testSelectSingleSlice},
270271
{testSelectMultiSlice},
271272
{testSelectJSONMap},
@@ -521,6 +522,51 @@ func testSelectStructSlice(t *testing.T, db *bun.DB) {
521522
}
522523
}
523524

525+
type CustomNum struct {
526+
Num int
527+
}
528+
529+
func (n *CustomNum) Scan(src any) error {
530+
switch val := src.(type) {
531+
case int32:
532+
*n = CustomNum{int(val)}
533+
case uint32:
534+
*n = CustomNum{int(val)}
535+
case int64:
536+
*n = CustomNum{int(val)}
537+
case uint64:
538+
*n = CustomNum{int(val)}
539+
default:
540+
return fmt.Errorf("unsupported type: %T", val)
541+
}
542+
return nil
543+
}
544+
545+
var _ sql.Scanner = (*CustomNum)(nil)
546+
547+
func testSelectScannerSlice(t *testing.T, db *bun.DB) {
548+
if !db.Dialect().Features().Has(feature.CTE) {
549+
t.Skip()
550+
}
551+
552+
values := db.NewValues(&[]map[string]interface{}{
553+
{"column1": 1},
554+
{"column1": 2},
555+
{"column1": 3},
556+
})
557+
558+
var ns []CustomNum
559+
err := db.NewSelect().
560+
With("t", values).
561+
TableExpr("t").
562+
Scan(ctx, &ns)
563+
require.NoError(t, err)
564+
require.Len(t, ns, 3)
565+
for i, n := range ns {
566+
require.Equal(t, i+1, n.Num)
567+
}
568+
}
569+
524570
func testSelectSingleSlice(t *testing.T, db *bun.DB) {
525571
if !db.Dialect().Features().Has(feature.CTE) {
526572
t.Skip()

model.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@ import (
1414
var errNilModel = errors.New("bun: Model(nil)")
1515

1616
var (
17-
timeType = reflect.TypeFor[time.Time]()
18-
bytesType = reflect.TypeFor[[]byte]()
17+
timeType = reflect.TypeFor[time.Time]()
18+
bytesType = reflect.TypeFor[[]byte]()
19+
scannerType = reflect.TypeFor[sql.Scanner]()
1920
)
2021

2122
type Model = schema.Model
@@ -125,7 +126,7 @@ func _newModel(db *DB, dest interface{}, scan bool) (Model, error) {
125126
case reflect.Slice:
126127
switch elemType := sliceElemType(v); elemType.Kind() {
127128
case reflect.Struct:
128-
if elemType != timeType {
129+
if elemType != timeType && !reflect.PointerTo(elemType).Implements(scannerType) {
129130
return newSliceTableModel(db, dest, v, elemType), nil
130131
}
131132
case reflect.Map:

0 commit comments

Comments
 (0)