From 5e30087db4a06c24c103737d4cb7dcdf06da5b24 Mon Sep 17 00:00:00 2001 From: Stefan Majewsky Date: Sun, 12 Apr 2026 17:18:43 +0200 Subject: add Store.SelectOne --- .golangci.yaml | 8 +++ benchmark/benchmark_test.go | 115 ++++++++++++++++++++++++++++++++++------- internal/assert/assert.go | 4 +- oblast.go | 4 ++ query.go | 92 --------------------------------- select.go | 122 ++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 233 insertions(+), 112 deletions(-) delete mode 100644 query.go create mode 100644 select.go diff --git a/.golangci.yaml b/.golangci.yaml index 995ee5a..286a9a5 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -32,6 +32,8 @@ linters: - prealloc - predeclared - reassign + - revive + - rowserrcheck - unconvert - usestdlibvars - usetesting @@ -71,6 +73,12 @@ linters: go-version-pattern: 1\.\d+(\.0)?$ nolintlint: require-specific: true + revive: + rules: + - name: exported + arguments: + - checkPrivateReceivers + - disableChecksOnConstants staticcheck: dot-import-whitelist: - github.com/majewsky/gg/option diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index e0822f2..b60951a 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -13,40 +13,50 @@ import ( "github.com/go-gorp/gorp/v3" _ "github.com/mattn/go-sqlite3" "go.xyrillian.de/oblast" + "go.xyrillian.de/oblast/internal/assert" ) -func BenchmarkSelect(b *testing.B) { - const totalRecordCount = 1000 +const totalRecordCount = 1000 - db, err := sql.Open("sqlite3", "file:foobar?mode=memory&cache=shared") +func makeTestDB(t testing.TB) (*sql.DB, error) { + db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())) if err != nil { - b.Fatal(err) + return nil, err } // fill in some random-looking, but deterministic data _, err = db.Exec(`CREATE TABLE entries (id INTEGER, message TEXT)`) if err != nil { - b.Fatal(err) + return nil, err } stmt, err := db.Prepare(`INSERT INTO entries (id, message) VALUES (?, ?)`) if err != nil { - b.Fatal(err) + return nil, err } for idx := range totalRecordCount { buf := sha256.Sum256([]byte(strconv.Itoa(idx))) _, err = stmt.Exec(idx, fmt.Sprintf("sha256:%x", buf[:])) if err != nil { - b.Fatal(err) + return nil, err } } err = stmt.Close() + if err != nil { + return nil, err + } + + return db, nil +} + +func BenchmarkSelect(b *testing.B) { + db, err := makeTestDB(b) if err != nil { b.Fatal(err) } // test with different sizes of resultsets (N=1 is an OLTP-like workload, // then the larger N lean more towards the OLAP side of things) - for _, selectedRecordCount := range []int{1, 10, 100, 1000} { + for selectedRecordCount := 1; selectedRecordCount < totalRecordCount; selectedRecordCount *= 10 { b.Run("N="+strconv.Itoa(selectedRecordCount), func(b *testing.B) { // prepare the functions that will be benched type record struct { @@ -65,9 +75,7 @@ func BenchmarkSelect(b *testing.B) { if err != nil { b.Error(err) } - if len(records) != selectedRecordCount { - b.Errorf("expected %d, but got %d records", selectedRecordCount, len(records)) - } + assert.Equal(b, len(records), selectedRecordCount) } selectWithGorp := func(b *testing.B) { @@ -76,14 +84,12 @@ func BenchmarkSelect(b *testing.B) { if err != nil { b.Error(err) } - if len(records) != selectedRecordCount { - b.Errorf("expected %d, but got %d records", selectedRecordCount, len(records)) - } + assert.Equal(b, len(records), selectedRecordCount) } selectWithSqlite := func(b *testing.B) { var count int - rows, err := db.Query(query) + rows, err := db.Query(query) //nolint:rowserrcheck // false positive if err != nil { b.Error(err) } @@ -104,9 +110,7 @@ func BenchmarkSelect(b *testing.B) { if err != nil { b.Error(err) } - if count != selectedRecordCount { - b.Errorf("expected %d, but got %d records", selectedRecordCount, count) - } + assert.Equal(b, count, selectedRecordCount) } // run once to prewarm caches @@ -135,3 +139,78 @@ func BenchmarkSelect(b *testing.B) { }) } } + +func BenchmarkSelectOne(b *testing.B) { + db, err := makeTestDB(b) + if err != nil { + b.Fatal(err) + } + + // grab a "random" record from the DB, not just the first or the last + recordID := min(totalRecordCount*2/3, totalRecordCount) + + // prepare the functions that will be benched + type record struct { + ID int `db:"id"` + Message string `db:"message"` + } + store, err := oblast.NewStore[record](oblast.SqliteDialect()) + if err != nil { + b.Fatal(err) + } + gdb := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}} + query := `SELECT * FROM entries WHERE id = ` + strconv.Itoa(recordID) + + selectWithOblast := func(b *testing.B) { + r, err := store.SelectOne(db, query) + if err != nil { + b.Error(err) + } + assert.Equal(b, r.ID, recordID) + } + + selectWithGorp := func(b *testing.B) { + var r record + err := gdb.SelectOne(&r, query) + if err != nil { + b.Error(err) + } + assert.Equal(b, r.ID, recordID) + } + + selectWithSqlite := func(b *testing.B) { + var ( + id int64 + message string + ) + err := db.QueryRow(query).Scan(&id, &message) + if err != nil { + b.Error(err) + } + assert.Equal(b, id, int64(recordID)) + } + + // run once to prewarm caches + selectWithOblast(b) + selectWithGorp(b) + if b.Failed() { + b.FailNow() + } + + // run actual benchmark + b.Run("via Gorp", func(b *testing.B) { + for range b.N { + selectWithGorp(b) + } + }) + b.Run("via Oblast", func(b *testing.B) { + for range b.N { + selectWithOblast(b) + } + }) + b.Run("just SQLite", func(b *testing.B) { + for range b.N { + selectWithSqlite(b) + } + }) +} diff --git a/internal/assert/assert.go b/internal/assert/assert.go index c4e7b50..c82d35c 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -9,7 +9,7 @@ import ( ) // Equal is a test assertion. -func Equal[V comparable](t *testing.T, actual, expected V) { +func Equal[V comparable](t testing.TB, actual, expected V) { t.Helper() if actual != expected { t.Errorf("expected %#v, but got %#v", expected, actual) @@ -17,7 +17,7 @@ func Equal[V comparable](t *testing.T, actual, expected V) { } // DeepEqual is a test assertion. -func DeepEqual[V any](t *testing.T, actual, expected V) { +func DeepEqual[V any](t testing.TB, actual, expected V) { t.Helper() if !reflect.DeepEqual(actual, expected) { t.Errorf("expected %#v, but got %#v", expected, actual) diff --git a/oblast.go b/oblast.go index be1571e..415b6cb 100644 --- a/oblast.go +++ b/oblast.go @@ -42,6 +42,7 @@ package oblast // import "go.xyrillian.de/oblast" import ( "database/sql" + "errors" "go.xyrillian.de/oblast/internal" ) @@ -75,3 +76,6 @@ var ( _ Handle = &sql.DB{} _ Handle = &sql.Tx{} ) + +// ErrMultipleRows is returned by [Store.SelectOne] if the query returned multiple rows. +var ErrMultipleRows = errors.New("sql: multiple rows in result set") diff --git a/query.go b/query.go deleted file mode 100644 index fd80f56..0000000 --- a/query.go +++ /dev/null @@ -1,92 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Stefan Majewsky -// SPDX-License-Identifier: Apache-2.0 - -package oblast - -import ( - "database/sql" - "fmt" - "reflect" - - "go.xyrillian.de/oblast/internal" -) - -func (s Store[R]) Select(db Handle, query string, args ...any) (result []R, returnedError error) { - // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. - // Any expression that does not depend on type R should be factored out into a reusable function. - - rows, indexes, err := startQuery(db, s.plan, query, args...) - if err != nil { - return nil, err - } - defer func() { - returnedError = mergeRowsCloseError(returnedError, rows.Close()) - }() - - slots := make([]any, len(indexes)) - for rows.Next() { - var target R - err = collectRow(rows, reflect.ValueOf(&target).Elem(), slots, indexes) - if err != nil { - return nil, err - } - result = append(result, target) - } - - return result, nil -} - -func startQuery(db Handle, plan internal.Plan, query string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { - rows, err = db.Query(query, args...) - if err != nil { - return nil, nil, fmt.Errorf("during Query(): %w", err) - } - defer func() { - if err != nil { - closeErr := rows.Close() - if closeErr != nil { - err = fmt.Errorf("%w (additional error during rows.Close(): %s)", err, closeErr.Error()) - } - } - }() - - columnNames, err := rows.Columns() - if err != nil { - return nil, nil, fmt.Errorf("during rows.Columns(): %w", err) - } - indexes = make([][]int, len(columnNames)) - for idx, columnName := range columnNames { - var ok bool - indexes[idx], ok = plan.IndexByColumnName[columnName] - if !ok { - return nil, nil, fmt.Errorf( - "result has column %q in position %d, but no field in record type has `db:%[1]q`", - columnName, idx, - ) - } - } - - return rows, indexes, nil -} - -func collectRow(rows *sql.Rows, v reflect.Value, slots []any, indexes [][]int) error { - for idx, index := range indexes { - slots[idx] = v.FieldByIndex(index).Addr().Interface() - } - err := rows.Scan(slots...) - if err != nil { - return fmt.Errorf("during rows.Scan(): %w", err) - } - return nil -} - -func mergeRowsCloseError(err, closeErr error) error { - switch { - case closeErr == nil: - return err - case err == nil: - return fmt.Errorf("during rows.Close(): %w", closeErr) - default: - return fmt.Errorf("%w (additional error during rows.Close(): %s)", err, closeErr.Error()) - } -} diff --git a/select.go b/select.go new file mode 100644 index 0000000..23521ed --- /dev/null +++ b/select.go @@ -0,0 +1,122 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky +// SPDX-License-Identifier: Apache-2.0 + +package oblast + +import ( + "database/sql" + "fmt" + "reflect" + + "go.xyrillian.de/oblast/internal" +) + +// Select executes the provided SQL query and fills an instance of the record type R for each row in the result set, +// according to the column names reported by the database as part of the result set. +// +// An error is returned if any column name in the result set does not correspond to an addressable field in R. +func (s Store[R]) Select(db Handle, query string, args ...any) (result []R, returnedError error) { + // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. + // Any expression that does not depend on type R should be factored out into a reusable function. + + rows, indexes, err := startQuery(db, s.plan, query, args...) + if err != nil { + return nil, err + } + defer func() { + returnedError = mergeRowsCloseError(returnedError, rows.Close()) + }() + + slots := make([]any, len(indexes)) + for rows.Next() { + var target R + err = collectRow(rows, reflect.ValueOf(&target).Elem(), slots, indexes) + if err != nil { + return nil, err + } + result = append(result, target) + } + + return result, nil +} + +func startQuery(db Handle, plan internal.Plan, query string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { + rows, err = db.Query(query, args...) + if err != nil { + return nil, nil, fmt.Errorf("during Query(): %w", err) + } + defer func() { + if err != nil { + closeErr := rows.Close() + if closeErr != nil { + err = fmt.Errorf("%w (additional error during rows.Close(): %s)", err, closeErr.Error()) + } + } + }() + + columnNames, err := rows.Columns() + if err != nil { + return nil, nil, fmt.Errorf("during rows.Columns(): %w", err) + } + indexes = make([][]int, len(columnNames)) + for idx, columnName := range columnNames { + var ok bool + indexes[idx], ok = plan.IndexByColumnName[columnName] + if !ok { + return nil, nil, fmt.Errorf( + "result has column %q in position %d, but no field in record type has `db:%[1]q`", + columnName, idx, + ) + } + } + + return rows, indexes, nil +} + +func collectRow(rows *sql.Rows, v reflect.Value, slots []any, indexes [][]int) error { + for idx, index := range indexes { + slots[idx] = v.FieldByIndex(index).Addr().Interface() + } + err := rows.Scan(slots...) + if err != nil { + return fmt.Errorf("during rows.Scan(): %w", err) + } + return nil +} + +func mergeRowsCloseError(err, closeErr error) error { + switch { + case closeErr == nil: + return err + case err == nil: + return fmt.Errorf("during rows.Close(): %w", closeErr) + default: + return fmt.Errorf("%w (additional error during rows.Close(): %s)", err, closeErr.Error()) + } +} + +// SelectOne executes the provided SQL query and fills an instance of the record type R if there is exactly one row in the result set, +// according to the column names reported by the database as part of the result set. +// +// If there are no rows in the result set, [sql.ErrNoRows] is returned. +// If there are multiple rows in the result set, [ErrMultipleRows] is returned. +// +// Warning: Because of limitations in the interface of database/sql, this function is built on [Store.Select] and cannot be any faster than it. +// For maximum performance, use [Store.SelectOneWhere] which avoids the overhead of potentially having to read multiple rows. +func (s Store[R]) SelectOne(db Handle, query string, args ...any) (result R, err error) { + // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. + // Any expression that does not depend on type R should be factored out into a reusable function. + var results []R + results, err = s.Select(db, query, args...) + if err == nil { + switch len(results) { + case 0: + err = sql.ErrNoRows + case 1: + result = results[0] + default: + err = ErrMultipleRows + } + } + return +} -- cgit v1.2.3