From 1a187cb04b3130572a5b3f7513c1e55b0a59fdc2 Mon Sep 17 00:00:00 2001 From: Stefan Majewsky Date: Mon, 13 Apr 2026 11:39:36 +0200 Subject: reduce code duplication in benchmark tests --- benchmark/benchmark_test.go | 134 ++++++++++---------------------------------- internal/must/must.go | 24 ++++++++ 2 files changed, 54 insertions(+), 104 deletions(-) create mode 100644 internal/must/must.go diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index d7399b6..f026ca6 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -14,40 +14,26 @@ import ( _ "github.com/mattn/go-sqlite3" "go.xyrillian.de/oblast" "go.xyrillian.de/oblast/internal/assert" + "go.xyrillian.de/oblast/internal/must" ) const totalRecordCountForSelect = 10000 -func makeTestDB(t testing.TB, recordCount int) (*sql.DB, error) { - db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())) - if err != nil { - return nil, err - } - _, err = db.Exec(`CREATE TABLE entries (id INTEGER, message TEXT, PRIMARY KEY (id AUTOINCREMENT))`) - if err != nil { - return nil, err - } +func makeTestDB(t testing.TB, recordCount int) *sql.DB { + db := must.Return(sql.Open("sqlite3", fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())))(t) + _ = must.Return(db.Exec(`CREATE TABLE entries (id INTEGER, message TEXT, PRIMARY KEY (id AUTOINCREMENT))`))(t) if recordCount > 0 { // fill in some random-looking, but deterministic data - stmt, err := db.Prepare(`INSERT INTO entries (id, message) VALUES (?, ?)`) - if err != nil { - return nil, err - } + stmt := must.Return(db.Prepare(`INSERT INTO entries (id, message) VALUES (?, ?)`))(t) for idx := range recordCount { buf := sha256.Sum256([]byte(strconv.Itoa(idx))) - _, err = stmt.Exec(idx, fmt.Sprintf("sha256:%x", buf[:])) - if err != nil { - return nil, err - } - } - err = stmt.Close() - if err != nil { - return nil, err + _ = must.Return(stmt.Exec(idx, fmt.Sprintf("sha256:%x", buf[:])))(t) } + must.Succeed(t, stmt.Close()) } - return db, nil + return db } type OblastEntry struct { @@ -61,10 +47,7 @@ type GorpEntry struct { } func BenchmarkSelectMany(b *testing.B) { - db, err := makeTestDB(b, totalRecordCountForSelect) - if err != nil { - b.Fatal(err) - } + db := makeTestDB(b, totalRecordCountForSelect) // test with different sizes of resultsets (N=1 is an OLTP-like workload, // then the larger N lean more towards the OLAP side of things) @@ -81,56 +64,38 @@ func BenchmarkSelectMany(b *testing.B) { } gdb := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}} partialQuery := `id < ` + strconv.Itoa(selectedRecordCount) - query := `SELECT * FROM entries WHERE ` + partialQuery //nolint:gosec + query := `SELECT * FROM entries WHERE ` + partialQuery selectWithOblast := func(b *testing.B) { - records, err := store.Select(db, query) - if err != nil { - b.Error(err) - } + records := must.Return(store.Select(db, query))(b) assert.Equal(b, len(records), selectedRecordCount) } selectWithOblastWhere := func(b *testing.B) { - records, err := store.SelectWhere(db, partialQuery) - if err != nil { - b.Error(err) - } + records := must.Return(store.SelectWhere(db, partialQuery))(b) assert.Equal(b, len(records), selectedRecordCount) } selectWithGorp := func(b *testing.B) { var records []GorpEntry - _, err := gdb.Select(&records, query) - if err != nil { - b.Error(err) - } + _ = must.Return(gdb.Select(&records, query))(b) assert.Equal(b, len(records), selectedRecordCount) } selectWithSqlite := func(b *testing.B) { var count int - rows, err := db.Query(query) //nolint:rowserrcheck // false positive - if err != nil { - b.Error(err) - } + rows := must.Return(db.Query(query))(b) //nolint:rowserrcheck // false positive var ( id int64 message string ) for rows.Next() { - err := rows.Scan(&id, &message) - if err != nil { - b.Error(err) - } + must.Succeed(b, rows.Scan(&id, &message)) if id != 20000 && message != "" { // always true; ensures that values are not optimized away count++ } } - err = rows.Close() - if err != nil { - b.Error(err) - } + must.Succeed(b, rows.Close()) assert.Equal(b, count, selectedRecordCount) } @@ -167,10 +132,7 @@ func BenchmarkSelectMany(b *testing.B) { } func BenchmarkSelectOne(b *testing.B) { - db, err := makeTestDB(b, totalRecordCountForSelect) - if err != nil { - b.Fatal(err) - } + db := makeTestDB(b, totalRecordCountForSelect) // grab a "random" record from the DB, not just the first or the last recordID := min(totalRecordCountForSelect*2/3, totalRecordCountForSelect) @@ -189,27 +151,18 @@ func BenchmarkSelectOne(b *testing.B) { query := `SELECT * FROM entries WHERE ` + partialQuery selectWithOblast := func(b *testing.B) { - r, err := store.SelectOne(db, query) - if err != nil { - b.Error(err) - } + r := must.Return(store.SelectOne(db, query))(b) assert.Equal(b, r.ID, recordID) } selectWithOblastWhere := func(b *testing.B) { - r, err := store.SelectOneWhere(db, partialQuery) - if err != nil { - b.Error(err) - } + r := must.Return(store.SelectOneWhere(db, partialQuery))(b) assert.Equal(b, r.ID, recordID) } selectWithGorp := func(b *testing.B) { var r GorpEntry - err := gdb.SelectOne(&r, query) - if err != nil { - b.Error(err) - } + must.Succeed(b, gdb.SelectOne(&r, query)) assert.Equal(b, r.ID, recordID) } @@ -218,10 +171,7 @@ func BenchmarkSelectOne(b *testing.B) { id int64 message string ) - err := db.QueryRow(query).Scan(&id, &message) - if err != nil { - b.Error(err) - } + must.Succeed(b, db.QueryRow(query).Scan(&id, &message)) assert.Equal(b, id, int64(recordID)) } @@ -256,10 +206,7 @@ func BenchmarkSelectOne(b *testing.B) { } func BenchmarkInsertAndDeleteOne(b *testing.B) { - db, err := makeTestDB(b, 0) - if err != nil { - b.Fatal(err) - } + db := makeTestDB(b, 0) // prepare the functions that will be benched store, err := oblast.NewStore[OblastEntry]( @@ -275,45 +222,24 @@ func BenchmarkInsertAndDeleteOne(b *testing.B) { insertAndDeleteWithOblast := func(b *testing.B) { record := OblastEntry{Message: "hello"} - err := store.Insert(db, &record) - if err != nil { - b.Error(err) - } + must.Succeed(b, store.Insert(db, &record)) if record.ID == 0 { b.Errorf("ID was not filled!") } - err = store.Delete(db, record) - if err != nil { - b.Error(err) - } + must.Succeed(b, store.Delete(db, record)) } insertAndDeleteWithGorp := func(b *testing.B) { record := GorpEntry{Message: "hello"} - err := gdb.Insert(&record) - if err != nil { - b.Error(err) - } + must.Succeed(b, gdb.Insert(&record)) if record.ID == 0 { b.Errorf("ID was not filled!") } - _, err = gdb.Delete(&record) - if err != nil { - b.Error(err) - } + _ = must.Return(gdb.Delete(&record))(b) } insertAndDeleteWithSqlite := func(b *testing.B) { - result, err := db.Exec(`INSERT INTO entries (message) VALUES (?)`, "hello") - if err != nil { - b.Error(err) - } - id, err := result.LastInsertId() - if err != nil { - b.Error(err) - } - _, err = db.Exec(`DELETE FROM entries WHERE id = ?`, id) - if err != nil { - b.Error(err) - } + result := must.Return(db.Exec(`INSERT INTO entries (message) VALUES (?)`, "hello"))(b) + id := must.Return(result.LastInsertId())(b) + _ = must.Return(db.Exec(`DELETE FROM entries WHERE id = ?`, id))(b) } // run once to prewarm caches @@ -331,7 +257,7 @@ func BenchmarkInsertAndDeleteOne(b *testing.B) { insertAndDeleteWithOblast(b) } }) - b.Run("via SQLite", func(b *testing.B) { + b.Run("just SQLite", func(b *testing.B) { for range b.N { insertAndDeleteWithSqlite(b) } diff --git a/internal/must/must.go b/internal/must/must.go new file mode 100644 index 0000000..e472579 --- /dev/null +++ b/internal/must/must.go @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky +// SPDX-License-Identifier: Apache-2.0 + +package must + +import "testing" + +// Succeed fails the test if err is not nil. +func Succeed(t testing.TB, err error) { + if err != nil { + t.Fatal(err.Error()) + } +} + +// Return wraps a function returning two output values, +// and either forwards the result value on success, or fails the test on error. +func Return[V any](value V, err error) func(testing.TB) V { + return func(t testing.TB) V { + if err != nil { + t.Fatal(err.Error()) + } + return value + } +} -- cgit v1.2.3