Skip to content

Commit 4196d19

Browse files
authored
refactor(sql): simplify Iter API to use iter.Seq2[T, error] (#536)
Refactored the RowsIter design to provide a cleaner and more idiomatic API: - Removed RowsIter struct and Err() method - Changed Iter() to directly return (iter.Seq2[T, error], error) - Errors are now yielded during iteration instead of checked afterwards - Added rows.Err() check after iteration completes - Updated all tests to use the new API signature Breaking Changes: - Iter() now returns (iter.Seq2[T, error], error) instead of *RowsIter[T] - Users must handle errors in the range loop: for v, err := range seq Benefits: - More concise API with less boilerplate - Immediate error handling in iteration loop - Follows Go's standard error handling patterns - Better suited for use in shortcuts Added: - QueryIterContext() shortcut function with automatic row cleanup Files changed: - sql/binder.go: Simplified Iter implementation (-78 lines) - sql/binder_test.go: Updated tests for new API - bind.go: Updated documentation and function signature - shortcuts.go: Added QueryIterContext helper
1 parent aaf013c commit 4196d19

File tree

4 files changed

+65
-78
lines changed

4 files changed

+65
-78
lines changed

bind.go

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@ limitations under the License.
1616

1717
package juice
1818

19-
import "github.com/go-juicedev/juice/sql"
19+
import (
20+
"iter"
21+
22+
"github.com/go-juicedev/juice/sql"
23+
)
2024

2125
// BindWithResultMap binds database query results to a value of type T using a custom ResultMap.
2226
// This function provides backward compatibility for code that imports the juice package directly.
@@ -115,22 +119,26 @@ func List2[T any](rows sql.Rows) ([]*T, error) {
115119
// For new code, consider using sql.Iter directly:
116120
//
117121
// import "github.com/go-juicedev/juice/sql"
118-
// iter := sql.Iter[User](rows)
122+
// seq, err := sql.Iter[User](rows)
119123
//
120-
// The iterator implements Go's iter.Seq[T] interface, allowing use in range loops:
124+
// The function returns an iter.Seq2[T, error] that yields (value, error) pairs,
125+
// allowing for immediate error handling in the loop:
121126
//
122-
// iter := Iter[User](rows)
123-
// for user := range iter.Iter() {
127+
// seq, err := Iter[User](rows)
128+
// if err != nil {
129+
// return err
130+
// }
131+
// for user, err := range seq {
132+
// if err != nil {
133+
// return fmt.Errorf("iteration failed: %w", err)
134+
// }
124135
// // Process each user
125136
// fmt.Println(user.Name)
126137
// }
127-
// if err := iter.Err(); err != nil {
128-
// // Handle iteration error
129-
// }
130138
//
131139
// Note: The caller is responsible for closing the rows when iteration is complete.
132140
//
133-
// Returns an iterator that yields values of type T.
134-
func Iter[T any](rows sql.Rows) *sql.RowsIter[T] {
141+
// Returns an iterator that yields (value, error) pairs of type T, and any initialization error.
142+
func Iter[T any](rows sql.Rows) (iter.Seq2[T, error], error) {
135143
return sql.Iter[T](rows)
136144
}

shortcuts.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package juice
1919
import (
2020
"context"
2121
"database/sql"
22+
"iter"
2223

2324
sqllib "github.com/go-juicedev/juice/sql"
2425
)
@@ -64,3 +65,16 @@ func QueryList2Context[T any](ctx context.Context, statement, param any) (result
6465
defer func() { _ = rows.Close() }()
6566
return sqllib.List2[T](rows)
6667
}
68+
69+
// QueryIterContext executes a query and returns an iterator over T.
70+
// Rows are automatically closed when iteration completes or stops.
71+
// (ctx must contain a Manager via ManagerFromContext)
72+
func QueryIterContext[T any](ctx context.Context, statement, param any) (iter.Seq2[T, error], error) {
73+
manager := ManagerFromContext(ctx)
74+
rows, err := manager.Object(statement).QueryContext(ctx, param)
75+
if err != nil {
76+
return nil, err
77+
}
78+
defer func() { _ = rows.Close() }()
79+
return sqllib.Iter[T](rows)
80+
}

sql/binder.go

Lines changed: 17 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package sql
1818

1919
import (
2020
"database/sql"
21-
"errors"
2221
"iter"
2322
"reflect"
2423
"time"
@@ -180,46 +179,16 @@ func List2[T any](rows Rows) ([]*T, error) {
180179
return result, nil
181180
}
182181

183-
// RowsIter provides an iterator interface for Rows.
184-
// It implements Go's built-in iter.Seq interface for type-safe iteration over database rows.
185-
// Type parameter T represents the type of values that will be yielded during iteration.
186-
type RowsIter[T any] struct {
187-
rows Rows // The underlying Rows to iterate over
188-
err error // Stores any error that occurs during iteration
189-
}
190-
191-
// Err returns any error that occurred during iteration.
192-
// This method should be checked after iteration is complete to ensure
193-
// no errors occurred while processing the rows.
194-
func (r *RowsIter[T]) Err() error {
195-
return errors.Join(r.err, r.rows.Err())
196-
}
197-
198-
// Iter implements the iter.Seq interface for row iteration.
199-
// It yields values of type T, automatically handling memory allocation
200-
// and type conversion for each row.
201-
//
202-
// Example usage:
203-
//
204-
// iter := Iter[User](rows)
205-
// for v := range iter.Iter() {
206-
// // Process each user
207-
// fmt.Println(v.Name)
208-
// }
209-
// if err := iter.Err(); err != nil {
210-
// // Handle error
211-
// }
212-
func (r *RowsIter[T]) Iter() iter.Seq[T] {
213-
columns, err := r.rows.Columns()
182+
func Iter[T any](rows Rows) (iter.Seq2[T, error], error) {
183+
columns, err := rows.Columns()
214184
if err != nil {
215-
r.err = err
216-
return func(func(T) bool) {}
185+
return nil, err
217186
}
187+
218188
columnDest := &rowDestination{}
219189
t := reflect.TypeFor[T]()
220190

221-
// Default object factory for non-pointer types
222-
var objectFactory = func() T { return *new(T) }
191+
var objectFactory func() T
223192

224193
isPtr := t.Kind() == reflect.Ptr
225194

@@ -229,6 +198,8 @@ func (r *RowsIter[T]) Iter() iter.Seq[T] {
229198
result, _ := reflect.TypeAssert[T](reflect.New(t.Elem()))
230199
return result
231200
}
201+
} else {
202+
objectFactory = func() T { return *new(T) }
232203
}
233204

234205
// handler encapsulates the row scanning logic and object creation
@@ -248,35 +219,23 @@ func (r *RowsIter[T]) Iter() iter.Seq[T] {
248219
if err != nil {
249220
return t, err
250221
}
251-
if err = r.rows.Scan(dest...); err != nil {
222+
if err = rows.Scan(dest...); err != nil {
252223
return t, err
253224
}
254225
return t, nil
255226
}
256227

257-
return func(yield func(T) bool) {
258-
259-
for r.rows.Next() {
228+
return func(yield func(T, error) bool) {
229+
for rows.Next() {
260230
value, err := handler()
261-
if err != nil {
262-
r.err = err
263-
return
264-
}
265-
if !yield(value) {
231+
if !yield(value, err) {
266232
return
267233
}
268234
}
269-
}
270-
}
271-
272-
// Iter creates an iterator over SQL rows that yields values of type T.
273-
// It handles both pointer and non-pointer types automatically and provides
274-
// proper memory management for each iteration.
275-
//
276-
// Note: This function does not close the Rows. The caller is responsible
277-
// for closing the rows when iteration is complete. This design allows for more
278-
// flexible resource management, especially when using the iterator in different
279-
// contexts or when early termination is needed.
280-
func Iter[T any](rows Rows) *RowsIter[T] {
281-
return &RowsIter[T]{rows: rows}
235+
// Check for any errors that occurred during iteration
236+
if err := rows.Err(); err != nil {
237+
var zero T
238+
yield(zero, err)
239+
}
240+
}, nil
282241
}

sql/binder_test.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,15 @@ func TestIter(t *testing.T) {
269269
Data: [][]any{{1, "Alice"}, {2, "Bob"}},
270270
}
271271
var users []TestUser
272-
iter := Iter[TestUser](rows)
273-
for user := range iter.Iter() {
274-
users = append(users, user)
272+
seq, err := Iter[TestUser](rows)
273+
if err != nil {
274+
t.Fatalf("Iter initialization failed: %v", err)
275275
}
276-
if err := iter.Err(); err != nil {
277-
t.Fatalf("Iter failed: %v", err)
276+
for user, err := range seq {
277+
if err != nil {
278+
t.Fatalf("Iter failed: %v", err)
279+
}
280+
users = append(users, user)
278281
}
279282
if len(users) != len(rows.Data) {
280283
t.Fatalf("Expected %d users, got %d", len(rows.Data), len(users))
@@ -291,14 +294,17 @@ func TestIter(t *testing.T) {
291294
ColumnsLine: []string{"id", "name"},
292295
Data: [][]any{},
293296
}
294-
iter := Iter[TestUser](rows)
297+
seq, err := Iter[TestUser](rows)
298+
if err != nil {
299+
t.Fatalf("Iter initialization failed: %v", err)
300+
}
295301
var users []TestUser
296-
for user := range iter.Iter() {
302+
for user, err := range seq {
303+
if err != nil {
304+
t.Fatalf("Iter failed with empty rows: %v", err)
305+
}
297306
users = append(users, user)
298307
}
299-
if err := iter.Err(); err != nil {
300-
t.Fatalf("Iter failed with empty rows: %v", err)
301-
}
302308
if len(users) != 0 {
303309
t.Errorf("Expected empty slice, got %d users", len(users))
304310
}

0 commit comments

Comments
 (0)