From 04965ab5e7f2490bcb09b2b3242de4a46f2b1043 Mon Sep 17 00:00:00 2001 From: Stefan Majewsky Date: Sat, 11 Apr 2026 00:32:22 +0200 Subject: reduce function body of Select() --- internal/plan.go | 11 +++++++++-- query.go | 53 +++++++++++++++++++++++++++++++++++++---------------- 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/internal/plan.go b/internal/plan.go index 8fc24d8..f738278 100644 --- a/internal/plan.go +++ b/internal/plan.go @@ -46,9 +46,16 @@ var ( // BuildPlan creates a new plan for the given struct type. func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) { + p, err := buildPlan(t, dialect) + if err != nil { + return Plan{}, fmt.Errorf("cannot use type %s.%s for queries: %w", t.PkgPath(), t.Name(), err) + } + return p, nil +} + +func buildPlan(t reflect.Type, dialect Dialect) (Plan, error) { if t.Kind() != reflect.Struct { - return Plan{}, fmt.Errorf("expected record type to be a struct, but got kind %s (full type: %s.%s)", - t.Kind(), t.PkgPath(), t.Name()) + return Plan{}, fmt.Errorf("expected struct type, but got kind %s", t.Kind().String()) } var p = Plan{ diff --git a/query.go b/query.go index 95c1f59..03b87b6 100644 --- a/query.go +++ b/query.go @@ -4,8 +4,11 @@ package oblast import ( + "database/sql" "fmt" "reflect" + + "go.xyrillian.de/oblast/internal" ) func Select[T any](db *DB, query string, args ...any) ([]T, error) { @@ -17,27 +20,12 @@ func Select[T any](db *DB, query string, args ...any) ([]T, error) { if err != nil { return nil, err } - rows, err := db.Query(query, args...) + rows, indexes, err := db.startQuery(plan, query, args...) if err != nil { return nil, err } defer rows.Close() - columnNames, err := rows.Columns() - if err != nil { - return nil, err - } - indexes := make([][]int, len(columnNames)) - for idx, columnName := range columnNames { - var ok bool - indexes[idx], ok = plan.IndexByColumnName[columnName] - if !ok { - var zero T - return nil, fmt.Errorf("result has column %q in position %d, but no field in %T has `db:%[1]q`", - columnName, idx, zero) - } - } - var result []T slots := make([]any, len(indexes)) for rows.Next() { @@ -55,3 +43,36 @@ func Select[T any](db *DB, query string, args ...any) ([]T, error) { return result, nil } + +func (db *DB) startQuery(plan internal.Plan, query string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { + rows, err = db.Query(query, args...) + if err != nil { + return nil, nil, err + } + defer func() { + if err != nil { + closeErr := rows.Close() + if closeErr != nil { + err = fmt.Errorf("%w (additional error during rows.Close: %s)", err, closeErr.Error()) + } + } + }() + + columnNames, err := rows.Columns() + if err != nil { + return nil, nil, err + } + indexes = make([][]int, len(columnNames)) + for idx, columnName := range columnNames { + var ok bool + indexes[idx], ok = plan.IndexByColumnName[columnName] + if !ok { + return nil, nil, fmt.Errorf( + "result has column %q in position %d, but no field in record type has `db:%[1]q`", + columnName, idx, + ) + } + } + + return rows, indexes, nil +} -- cgit v1.2.3