diff options
Diffstat (limited to 'query.go')
| -rw-r--r-- | query.go | 19 |
1 files changed, 8 insertions, 11 deletions
@@ -11,15 +11,12 @@ import ( "go.xyrillian.de/oblast/internal" ) -func Select[T any](db *DB, query string, args ...any) (result []T, returnedError error) { +// TODO: allow taking *sql.Tx in addition to *sql.DB +func (s Store[R]) Select(db *sql.DB, query string, args ...any) (result []R, returnedError error) { // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. - // Any expression that does not depend on type T should be factored out into a reusable function. + // Any expression that does not depend on type R should be factored out into a reusable function. - plan, err := db.getPlan(reflect.TypeFor[T]()) - if err != nil { - return nil, err - } - rows, indexes, err := db.startQuery(plan, query, args...) + rows, indexes, err := startQuery(db, s.plan, query, args...) if err != nil { return nil, err } @@ -29,8 +26,8 @@ func Select[T any](db *DB, query string, args ...any) (result []T, returnedError slots := make([]any, len(indexes)) for rows.Next() { - var target T - err = db.collectRow(rows, reflect.ValueOf(&target).Elem(), slots, indexes) + var target R + err = collectRow(rows, reflect.ValueOf(&target).Elem(), slots, indexes) if err != nil { return nil, err } @@ -40,7 +37,7 @@ func Select[T any](db *DB, query string, args ...any) (result []T, returnedError return result, nil } -func (db *DB) startQuery(plan internal.Plan, query string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { +func startQuery(db *sql.DB, plan internal.Plan, query string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { rows, err = db.Query(query, args...) if err != nil { return nil, nil, fmt.Errorf("during Query(): %w", err) @@ -73,7 +70,7 @@ func (db *DB) startQuery(plan internal.Plan, query string, args ...any) (rows *s return rows, indexes, nil } -func (db *DB) collectRow(rows *sql.Rows, v reflect.Value, slots []any, indexes [][]int) error { +func collectRow(rows *sql.Rows, v reflect.Value, slots []any, indexes [][]int) error { for idx, index := range indexes { slots[idx] = v.FieldByIndex(index).Addr().Interface() } |
