diff options
Diffstat (limited to 'query.go')
| -rw-r--r-- | query.go | 46 |
1 files changed, 32 insertions, 14 deletions
@@ -11,10 +11,9 @@ import ( "go.xyrillian.de/oblast/internal" ) -func Select[T any](db *DB, query string, args ...any) ([]T, error) { - // TODO: minimize function body to avoid binary size blowup from monomorphization - // TODO: catch error from rows.Close(), if any - // TODO: add context to errors +func Select[T any](db *DB, query string, args ...any) (result []T, returnedError error) { + // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. + // Any expression that does not depend on type T should be factored out into a reusable function. plan, err := db.getPlan(reflect.TypeFor[T]()) if err != nil { @@ -24,17 +23,14 @@ func Select[T any](db *DB, query string, args ...any) ([]T, error) { if err != nil { return nil, err } - defer rows.Close() + defer func() { + returnedError = mergeRowsCloseError(returnedError, rows.Close()) + }() - var result []T slots := make([]any, len(indexes)) for rows.Next() { var target T - rvalue := reflect.ValueOf(&target).Elem() - for idx, index := range indexes { - slots[idx] = rvalue.FieldByIndex(index).Addr().Interface() - } - err := rows.Scan(slots...) + err = db.collectRow(rows, reflect.ValueOf(&target).Elem(), slots, indexes) if err != nil { return nil, err } @@ -47,20 +43,20 @@ func Select[T any](db *DB, query string, args ...any) ([]T, error) { func (db *DB) startQuery(plan internal.Plan, query string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { rows, err = db.Query(query, args...) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("during Query(): %w", err) } defer func() { if err != nil { closeErr := rows.Close() if closeErr != nil { - err = fmt.Errorf("%w (additional error during rows.Close: %s)", err, closeErr.Error()) + err = fmt.Errorf("%w (additional error during rows.Close(): %s)", err, closeErr.Error()) } } }() columnNames, err := rows.Columns() if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("during rows.Columns(): %w", err) } indexes = make([][]int, len(columnNames)) for idx, columnName := range columnNames { @@ -76,3 +72,25 @@ func (db *DB) startQuery(plan internal.Plan, query string, args ...any) (rows *s return rows, indexes, nil } + +func (db *DB) collectRow(rows *sql.Rows, v reflect.Value, slots []any, indexes [][]int) error { + for idx, index := range indexes { + slots[idx] = v.FieldByIndex(index).Addr().Interface() + } + err := rows.Scan(slots...) + if err != nil { + return fmt.Errorf("during rows.Scan(): %w", err) + } + return nil +} + +func mergeRowsCloseError(err, closeErr error) error { + switch { + case closeErr == nil: + return err + case err == nil: + return fmt.Errorf("during rows.Close(): %w", closeErr) + default: + return fmt.Errorf("%w (additional error during rows.Close(): %s)", err, closeErr.Error()) + } +} |
