Skip to content

Commit

Permalink
support calling query sequences multiples times
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <achille.roussel@gmail.com>
  • Loading branch information
achille-roussel committed Feb 2, 2024
1 parent 6698ca7 commit 3213ba9
Showing 1 changed file with 34 additions and 34 deletions.
68 changes: 34 additions & 34 deletions sqlrange.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ func Query[Row any](q Queryable, query string, args ...any) iter.Seq2[Row, error
// QueryContext returns the results of the query as a sequence of rows.
//
// The returned function automatically closes the unerlying sql.Rows value when
// it completes its iteration. The function can only be iterated once, it will
// not retain the values that it has seen.
// it completes its iteration.
//
// A typical use of QueryContext is:
//
Expand All @@ -210,21 +209,20 @@ func Query[Row any](q Queryable, query string, args ...any) iter.Seq2[Row, error
// See Scan for more information about how the rows are mapped to the row type
// parameter Row.
func QueryContext[Row any](ctx context.Context, q Queryable, query string, args ...any) iter.Seq2[Row, error] {
rows, err := q.QueryContext(ctx, query, args...)
if err != nil {
return func(yield func(Row, error) bool) {
return func(yield func(Row, error) bool) {
if rows, err := q.QueryContext(ctx, query, args...); err != nil {
var zero Row
yield(zero, err)
} else {
scan[Row](yield, rows)
}
}
return Scan[Row](rows)
}

// Scan returns a sequence of rows from a sql.Rows value.
//
// The returned function automatically closes the rows passed as argument when
// it completes its iteration. The function can only be iterated once, it will
// not retain the values that it has seen.
// it completes its iteration.
//
// A typical use of Scan is:
//
Expand Down Expand Up @@ -254,40 +252,42 @@ func QueryContext[Row any](ctx context.Context, q Queryable, query string, args
// Ranging over the returned function will panic if the type parameter is not a
// struct.
func Scan[Row any](rows *sql.Rows) iter.Seq2[Row, error] {
return func(yield func(Row, error) bool) {
defer rows.Close()
var zero Row
return func(yield func(Row, error) bool) { scan(yield, rows) }
}

columns, err := rows.Columns()
if err != nil {
yield(zero, err)
return
}
func scan[Row any](yield func(Row, error) bool, rows *sql.Rows) {
defer rows.Close()
var zero Row

scanArgs := make([]any, len(columns))
row := new(Row)
val := reflect.ValueOf(row).Elem()
columns, err := rows.Columns()
if err != nil {
yield(zero, err)
return
}

for columnName, structField := range Fields(val.Type()) {
if columnIndex := slices.Index(columns, columnName); columnIndex >= 0 {
scanArgs[columnIndex] = val.FieldByIndex(structField.Index).Addr().Interface()
}
}
scanArgs := make([]any, len(columns))
row := new(Row)
val := reflect.ValueOf(row).Elem()

for rows.Next() {
if err := rows.Scan(scanArgs...); err != nil {
yield(zero, err)
return
}
if !yield(*row, nil) {
return
}
*row = zero
for columnName, structField := range Fields(val.Type()) {
if columnIndex := slices.Index(columns, columnName); columnIndex >= 0 {
scanArgs[columnIndex] = val.FieldByIndex(structField.Index).Addr().Interface()
}
}

if err := rows.Err(); err != nil {
for rows.Next() {
if err := rows.Scan(scanArgs...); err != nil {
yield(zero, err)
return
}
if !yield(*row, nil) {
return
}
*row = zero
}

if err := rows.Err(); err != nil {
yield(zero, err)
}
}

Expand Down

0 comments on commit 3213ba9

Please sign in to comment.