aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Majewsky <majewsky@gmx.net>2026-04-11 00:32:22 +0200
committerStefan Majewsky <majewsky@gmx.net>2026-04-11 00:32:22 +0200
commit04965ab5e7f2490bcb09b2b3242de4a46f2b1043 (patch)
treec05a1f383495efe2040d33667e52bbca0dba6fd9
parent30984e9e686cc1dc2bede6784a2534371b02219d (diff)
downloadgo-oblast-04965ab5e7f2490bcb09b2b3242de4a46f2b1043.tar.gz
reduce function body of Select()
-rw-r--r--internal/plan.go11
-rw-r--r--query.go53
2 files changed, 46 insertions, 18 deletions
diff --git a/internal/plan.go b/internal/plan.go
index 8fc24d8..f738278 100644
--- a/internal/plan.go
+++ b/internal/plan.go
@@ -46,9 +46,16 @@ var (
// BuildPlan creates a new plan for the given struct type.
func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) {
+ p, err := buildPlan(t, dialect)
+ if err != nil {
+ return Plan{}, fmt.Errorf("cannot use type %s.%s for queries: %w", t.PkgPath(), t.Name(), err)
+ }
+ return p, nil
+}
+
+func buildPlan(t reflect.Type, dialect Dialect) (Plan, error) {
if t.Kind() != reflect.Struct {
- return Plan{}, fmt.Errorf("expected record type to be a struct, but got kind %s (full type: %s.%s)",
- t.Kind(), t.PkgPath(), t.Name())
+ return Plan{}, fmt.Errorf("expected struct type, but got kind %s", t.Kind().String())
}
var p = Plan{
diff --git a/query.go b/query.go
index 95c1f59..03b87b6 100644
--- a/query.go
+++ b/query.go
@@ -4,8 +4,11 @@
package oblast
import (
+ "database/sql"
"fmt"
"reflect"
+
+ "go.xyrillian.de/oblast/internal"
)
func Select[T any](db *DB, query string, args ...any) ([]T, error) {
@@ -17,27 +20,12 @@ func Select[T any](db *DB, query string, args ...any) ([]T, error) {
if err != nil {
return nil, err
}
- rows, err := db.Query(query, args...)
+ rows, indexes, err := db.startQuery(plan, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
- columnNames, err := rows.Columns()
- if err != nil {
- return nil, err
- }
- indexes := make([][]int, len(columnNames))
- for idx, columnName := range columnNames {
- var ok bool
- indexes[idx], ok = plan.IndexByColumnName[columnName]
- if !ok {
- var zero T
- return nil, fmt.Errorf("result has column %q in position %d, but no field in %T has `db:%[1]q`",
- columnName, idx, zero)
- }
- }
-
var result []T
slots := make([]any, len(indexes))
for rows.Next() {
@@ -55,3 +43,36 @@ func Select[T any](db *DB, query string, args ...any) ([]T, error) {
return result, nil
}
+
+func (db *DB) startQuery(plan internal.Plan, query string, args ...any) (rows *sql.Rows, indexes [][]int, err error) {
+ rows, err = db.Query(query, args...)
+ if err != nil {
+ return nil, nil, err
+ }
+ defer func() {
+ if err != nil {
+ closeErr := rows.Close()
+ if closeErr != nil {
+ err = fmt.Errorf("%w (additional error during rows.Close: %s)", err, closeErr.Error())
+ }
+ }
+ }()
+
+ columnNames, err := rows.Columns()
+ if err != nil {
+ return nil, nil, err
+ }
+ indexes = make([][]int, len(columnNames))
+ for idx, columnName := range columnNames {
+ var ok bool
+ indexes[idx], ok = plan.IndexByColumnName[columnName]
+ if !ok {
+ return nil, nil, fmt.Errorf(
+ "result has column %q in position %d, but no field in record type has `db:%[1]q`",
+ columnName, idx,
+ )
+ }
+ }
+
+ return rows, indexes, nil
+}