aboutsummaryrefslogtreecommitdiff
path: root/select.go
blob: 8aed2490585e3a8dcfba25a716d340921c1c2c33 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net>
// SPDX-License-Identifier: Apache-2.0

package oblast

import (
	"database/sql"
	"errors"
	"fmt"
	"reflect"
)

// Select executes the provided SQL query and fills an instance of the record type R for each row in the result set,
// according to the column names reported by the database as part of the result set.
//
// An error is returned if any column name in the result set does not correspond to an addressable field in R.
// An error is also returned (and no partial result) if reading the result set fails partway through,
// as reported by [sql.Rows.Err].
// If the result set is empty, the result is a nil slice and no error.
func (s Store[R]) Select(db Handle, query string, args ...any) (result []R, returnedError error) {
	// NOTE: This function body should be as short as possible to reduce the binary size after monomorphization.
	//       Any expression that does not depend on type R should be factored out into a reusable function.

	rows, indexes, err := startSelectQuery(db, s.plan, query, args...)
	if err != nil {
		return nil, err
	}
	defer func() {
		returnedError = mergeCloseError("Rows", returnedError, rows.Close())
	}()

	slots := make([]any, len(indexes)) // scratch space for Scan targets, reused for every row
	for rows.Next() {
		var target R
		err = collectRow(rows, reflect.ValueOf(&target).Elem(), slots, indexes)
		if err != nil {
			return nil, err
		}
		result = append(result, target)
	}
	// rows.Next() returns false both at the end of the result set and when iteration failed,
	// so the latter must be checked explicitly. Otherwise a truncated result would be reported as complete.
	err = rows.Err()
	if err != nil {
		return nil, fmt.Errorf("during rows.Next(): %w", err)
	}

	return result, nil
}

// SelectWhere is like [Store.Select], but you only provide the part of the SELECT query that comes after the WHERE.
// The initial part ("SELECT ... FROM ... WHERE") is autogenerated and prepended to partialQuery.
// This has two benefits:
//   - It is more efficient because the strategy for loading result rows into the record type R has already been precomputed during [NewStore],
//     whereas a regular [Store.Select] must inspect the column names in the result set for each [Store.Select] call.
//   - For record types that contain only some of the columns of the corresponding database table,
//     the autogenerated SELECT query will only load exactly the necessary fields and nothing else.
//
// partialQuery is implied to start right after the WHERE keyword, which is added automatically.
// To select all records unconditionally, provide a partialQuery of "TRUE", leading to a full query of "SELECT ... FROM ... WHERE TRUE".
// Besides a condition for the WHERE clause, it may contain additional clauses, such as ORDER BY or LIMIT.
// Because partialQuery is spliced into the query text verbatim, it must not be built from untrusted input;
// pass variable values through args instead.
//
// Returns an error if [NewStore] was called without the [TableNameIs] option, which is required to generate a query for this method.
// Also returns an error (and no partial result) if reading the result set fails partway through, as reported by [sql.Rows.Err].
// If the result set is empty, the result is a nil slice and no error.
func (s Store[R]) SelectWhere(db Handle, partialQuery string, args ...any) (result []R, returnedError error) {
	// NOTE: This function body should be as short as possible to reduce the binary size after monomorphization.
	//       Any expression that does not depend on type R should be factored out into a reusable function.

	rows, indexes, err := startSelectWhereQuery(db, s.plan, partialQuery, args...)
	if err != nil {
		return nil, err
	}
	defer func() {
		returnedError = mergeCloseError("Rows", returnedError, rows.Close())
	}()

	slots := make([]any, len(indexes)) // scratch space for Scan targets, reused for every row
	for rows.Next() {
		var target R
		err = collectRow(rows, reflect.ValueOf(&target).Elem(), slots, indexes)
		if err != nil {
			return nil, err
		}
		result = append(result, target)
	}
	// rows.Next() returns false both at the end of the result set and when iteration failed,
	// so the latter must be checked explicitly. Otherwise a truncated result would be reported as complete.
	err = rows.Err()
	if err != nil {
		return nil, fmt.Errorf("during rows.Next(): %w", err)
	}

	return result, nil
}

func startSelectQuery(db Handle, plan plan, query string, args ...any) (returnedRows *sql.Rows, indexes [][]int, returnedError error) {
	rows, err := db.Query(query, args...)
	if err != nil {
		return nil, nil, fmt.Errorf("during Query(): %w", err)
	}
	defer func() {
		if returnedError != nil {
			closeErr := rows.Close() // NOTE: Not `returnedRows.Close()`! We may have `rows != nil && returnedRows == nil`.
			if closeErr != nil {
				returnedError = fmt.Errorf("%w (additional error during rows.Close(): %s)", returnedError, closeErr.Error())
			}
		}
	}()

	columnNames, err := rows.Columns()
	if err != nil {
		return nil, nil, fmt.Errorf("during rows.Columns(): %w", err)
	}
	indexes = make([][]int, len(columnNames))
	for idx, columnName := range columnNames {
		var ok bool
		indexes[idx], ok = plan.IndexByColumnName[columnName]
		if !ok {
			return nil, nil, fmt.Errorf(
				"result has column %q in position %d, but no field in type %s has `db:%[1]q`",
				columnName, idx, plan.TypeName,
			)
		}
	}

	return rows, indexes, nil
}

// startSelectWhereQuery runs the SELECT query that was autogenerated for plan, with partialQuery appended to it.
//
// Unlike startSelectQuery, the column-to-field mapping is not derived from the result set.
// plan.Select.ScanIndexes was precomputed alongside the autogenerated query and is returned as-is.
// Returns an error if no query was autogenerated (plan.Select.Query is empty) or if the query fails.
// On success, the caller owns the returned rows and must close them.
func startSelectWhereQuery(db Handle, plan plan, partialQuery string, args ...any) (rows *sql.Rows, indexes [][]int, err error) {
	if plan.Select.Query == "" {
		return nil, nil, errors.New("cannot execute SelectWhere() because query could not be autogenerated")
	}
	rows, err = db.Query(plan.Select.Query+partialQuery, args...)
	if err != nil {
		// Wrap the same way as startSelectQuery, so that Select and SelectWhere report query failures consistently.
		return nil, nil, fmt.Errorf("during Query(): %w", err)
	}
	return rows, plan.Select.ScanIndexes, nil
}

// collectRow scans the current row of rows into the fields of the struct value v.
//
// indexes[i] is the field index path (as accepted by [reflect.Value.FieldByIndex]) that receives column i.
// slots is caller-provided scratch space that is overwritten here. Every element of slots is passed to Scan,
// so it should have exactly len(indexes) elements. The callers allocate it once and reuse it for all rows.
// v must be addressable, because a pointer to each target field is handed to Scan.
func collectRow(rows *sql.Rows, v reflect.Value, slots []any, indexes [][]int) error {
	for i := range indexes {
		field := v.FieldByIndex(indexes[i])
		slots[i] = field.Addr().Interface()
	}
	return rows.Scan(slots...)
}

// mergeCloseError folds the result of a deferred Close() call on a value of the given type name
// into the error that the surrounding function is about to return.
//
// If closeErr is nil, err is returned unchanged (and may itself be nil).
// If only closeErr is non-nil, it is wrapped with a message naming typeName.
// If both are non-nil, err remains the wrapped error and closeErr is only appended as text,
// so errors.Is/errors.As will match err but not closeErr.
func mergeCloseError(typeName string, err, closeErr error) error {
	if closeErr == nil {
		return err
	}
	if err == nil {
		return fmt.Errorf("during %s.Close(): %w", typeName, closeErr)
	}
	return fmt.Errorf("%w (additional error during %s.Close(): %s)", err, typeName, closeErr.Error())
}

// SelectOne executes the provided SQL query and fills an instance of the record type R if there is exactly one row in the result set,
// according to the column names reported by the database as part of the result set.
//
// If there are no rows in the result set, [sql.ErrNoRows] is returned.
//
// Warning: Because of limitations in the interface of database/sql, this function is built on [Store.Select] and cannot be any faster than it.
// For maximum performance, use [Store.SelectOneWhere] which avoids the overhead of potentially having to read multiple rows.
func (s Store[R]) SelectOne(db Handle, query string, args ...any) (result R, err error) {
	// NOTE: This function body should be as short as possible to reduce the binary size after monomorphization.
	//       Any expression that does not depend on type R should be factored out into a reusable function.

	var results []R
	results, err = s.Select(db, query, args...)
	if err == nil {
		if len(results) == 0 {
			err = sql.ErrNoRows
		} else {
			result = results[0]
		}
	}
	return
}

// SelectOneWhere is like [Store.SelectOne], but you only provide the part of the SELECT query that comes after the WHERE.
// See [Store.SelectWhere] for an explanation of how the full query is constructed from this partial query.
//
// This method is significantly more efficient than [Store.SelectOne].
// Prefer using it instaed of [Store.SelectOne] whenever possible.
func (s Store[R]) SelectOneWhere(db Handle, partialQuery string, args ...any) (result R, err error) {
	// NOTE: This function body should be as short as possible to reduce the binary size after monomorphization.
	//       Any expression that does not depend on type R should be factored out into a reusable function.

	err = selectOneWhere(db, s.plan, reflect.ValueOf(&result).Elem(), partialQuery, args)
	return
}

// selectOneWhere is the non-generic implementation of [Store.SelectOneWhere].
//
// It appends partialQuery to the autogenerated SELECT query stored in plan, runs it through db.QueryRow,
// and scans the resulting row into the fields of v selected by plan.Select.ScanIndexes.
// v must be addressable, because pointers to its fields are handed to Scan.
// Returns an error if no query was autogenerated; otherwise the error from Scan is returned as-is (not wrapped).
func selectOneWhere(db Handle, plan plan, v reflect.Value, partialQuery string, args []any) error {
	baseQuery := plan.Select.Query
	if baseQuery == "" {
		return errors.New("cannot execute SelectOneWhere() because query could not be autogenerated")
	}

	targets := make([]any, 0, len(plan.Select.ScanIndexes))
	for _, fieldIndex := range plan.Select.ScanIndexes {
		targets = append(targets, v.FieldByIndex(fieldIndex).Addr().Interface())
	}

	row := db.QueryRow(baseQuery+partialQuery, args...)
	return row.Scan(targets...)
}