From a561ebb42148c72638f943e44191da07c16df7f6 Mon Sep 17 00:00:00 2001 From: Stefan Majewsky Date: Wed, 13 May 2026 01:11:30 +0200 Subject: return a concrete type from Wrap() to enable non-Oblast DB operations --- CHANGELOG.md | 6 +++ benchmark/benchmark_test.go | 68 ++++++++++++++++----------------- benchmark/internal/oblast_pgx/handle.go | 8 ++-- benchmark/postgres_test.go | 50 +++++++++++------------- handle.go | 57 ++++++++++++++++----------- handle/handle.go | 10 +++-- query.go | 2 +- select.go | 6 +-- 8 files changed, 110 insertions(+), 97 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f0aa71..a260a2d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ SPDX-License-Identifier: Apache-2.0 # v0.8.0 (TBD) +API changes: + +- `Wrap` now returns a struct type `SqlHandle` instead of the interface type `Handle`. + This enables reaching into the `SqlHandle` and getting the original `*sql.DB` and `*sql.Tx` back out, + which is more ergonomic in situations where Oblast loads/stores need to be mixed with other types of DB operations. + Changes: - Insert, Upsert, Update and Delete will no longer panic when one of the fields they need to access is within a pointer-to-struct that is nil. diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index da764b2..4bc7950 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -43,14 +43,14 @@ var ( batchSizesForUpdate = []int{1, 2, 4, 8, 16, 100} ) -func makeSqliteTestDB(t testing.TB, recordCount int) (db *sql.DB, dsn string) { +func makeSqliteTestDB(t testing.TB, recordCount int) (db oblast.SqlHandle[*sql.DB], dsn string) { dsn = fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name()) - db = must.Return(sql.Open("sqlite3", dsn))(t) - _ = must.Return(db.Exec(`CREATE TABLE entries (id INTEGER, message TEXT, PRIMARY KEY (id AUTOINCREMENT))`))(t) + db = oblast.Wrap(must.Return(sql.Open("sqlite3", dsn))(t)) + _ = must.Return(db.Base.Exec(`CREATE TABLE entries (id INTEGER, message TEXT, PRIMARY KEY (id AUTOINCREMENT))`))(t) if recordCount > 0 { // fill in some random-looking, but deterministic data - stmt := must.Return(db.Prepare(`INSERT INTO entries (id, message) VALUES (?, ?)`))(t) + stmt := must.Return(db.Base.Prepare(`INSERT INTO entries (id, message) VALUES (?, ?)`))(t) for idx := range recordCount { buf := sha256.Sum256([]byte(strconv.Itoa(idx))) _ = must.Return(stmt.Exec(idx, fmt.Sprintf("sha256:%x", buf[:])))(t) @@ -80,7 +80,6 @@ func (GormEntry) TableName() string { return "entries" } func BenchmarkORMSelectMany(b *testing.B) { db, dsn := makeSqliteTestDB(b, totalRecordCountForSelect) - dbh := oblast.Wrap(db) // test with different sizes of resultsets (N=1 is an OLTP-like workload, // then the larger N lean more towards the OLAP side of things) @@ -92,19 +91,19 @@ func BenchmarkORMSelectMany(b *testing.B) { oblast.TableNameIs("entries"), oblast.PrimaryKeyIs("id"), ) - gorpDB := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}} + gorpDB := gorp.DbMap{Db: db.Base, Dialect: gorp.SqliteDialect{}} gormDB := must.Return(gorm.Open(sqlite.Open(dsn), &gorm.Config{}))(b) partialQuery := `id < ` + strconv.Itoa(batchSize) query := `SELECT * FROM entries WHERE ` + partialQuery precomputedQuery := store.MustPrepareSelectQueryWhere(partialQuery) selectWithOblast := func(b *testing.B) { - records := must.Return(store.Select(noctx, dbh, query))(b) + records := must.Return(store.Select(noctx, db, query))(b) assert.Equal(b, len(records), batchSize) } selectWithOblastWhere := func(b *testing.B) { - records := must.Return(precomputedQuery.Select(noctx, dbh))(b) + records := must.Return(precomputedQuery.Select(noctx, db))(b) assert.Equal(b, len(records), batchSize) } @@ -121,7 +120,7 @@ func BenchmarkORMSelectMany(b *testing.B) { selectWithSqlite := func(b *testing.B) { var count int - rows := must.Return(db.Query(query))(b) //nolint:rowserrcheck // false positive + rows := must.Return(db.Base.Query(query))(b) //nolint:rowserrcheck // false positive var ( id int64 message string @@ -176,7 +175,6 @@ func BenchmarkORMSelectMany(b *testing.B) { func BenchmarkORMSelectOne(b *testing.B) { db, dsn := makeSqliteTestDB(b, totalRecordCountForSelect) - dbh := oblast.Wrap(db) // grab a "random" record from the DB, not just the first or the last recordID := min(totalRecordCountForSelect*2/3, totalRecordCountForSelect) @@ -187,19 +185,19 @@ func BenchmarkORMSelectOne(b *testing.B) { oblast.TableNameIs("entries"), oblast.PrimaryKeyIs("id"), ) - gorpDB := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}} + gorpDB := gorp.DbMap{Db: db.Base, Dialect: gorp.SqliteDialect{}} gormDB := must.Return(gorm.Open(sqlite.Open(dsn), &gorm.Config{}))(b) partialQuery := `id = ` + strconv.Itoa(recordID) query := `SELECT * FROM entries WHERE ` + partialQuery precomputedQuery := store.MustPrepareSelectQueryWhere(partialQuery) selectWithOblast := func(b *testing.B) { - r := must.Return(store.SelectOne(noctx, dbh, query))(b) + r := must.Return(store.SelectOne(noctx, db, query))(b) assert.Equal(b, r.ID, recordID) } selectWithOblastWhere := func(b *testing.B) { - r := must.Return(precomputedQuery.SelectOne(noctx, dbh))(b) + r := must.Return(precomputedQuery.SelectOne(noctx, db))(b) assert.Equal(b, r.ID, recordID) } @@ -219,7 +217,7 @@ func BenchmarkORMSelectOne(b *testing.B) { id int64 message string ) - must.Succeed(b, db.QueryRow(query).Scan(&id, &message)) + must.Succeed(b, db.Base.QueryRow(query).Scan(&id, &message)) assert.Equal(b, id, int64(recordID)) } @@ -261,14 +259,13 @@ func BenchmarkORMSelectOne(b *testing.B) { func BenchmarkORMInsertAndDelete(b *testing.B) { db, dsn := makeSqliteTestDB(b, 0) - dbh := oblast.Wrap(db) store := oblast.MustNewStore[OblastEntry]( oblast.SqliteDialect(), oblast.TableNameIs("entries"), oblast.PrimaryKeyIs("id"), ) - gorpDB := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}} + gorpDB := gorp.DbMap{Db: db.Base, Dialect: gorp.SqliteDialect{}} gorpDB.AddTableWithName(GorpEntry{}, "entries").SetKeys(true, "id") gormDB := must.Return(gorm.Open(sqlite.Open(dsn), &gorm.Config{}))(b) @@ -283,22 +280,22 @@ func BenchmarkORMInsertAndDelete(b *testing.B) { records[idx] = OblastEntry{Message: "hello"} recordsForInsert[idx] = &records[idx] } - must.Succeed(b, store.Insert(noctx, dbh, recordsForInsert...)) + must.Succeed(b, store.Insert(noctx, db, recordsForInsert...)) for _, r := range records { if r.ID == 0 { b.Errorf("ID was not filled!") } } - must.Succeed(b, store.Delete(noctx, dbh, records...)) + must.Succeed(b, store.Delete(noctx, db, records...)) } if batchSize == 1 { insertAndDeleteWithOblast = func(b *testing.B) { record := OblastEntry{Message: "hello"} - must.Succeed(b, store.Insert(noctx, dbh, &record)) + must.Succeed(b, store.Insert(noctx, db, &record)) if record.ID == 0 { b.Errorf("ID was not filled!") } - must.Succeed(b, store.Delete(noctx, dbh, record)) + must.Succeed(b, store.Delete(noctx, db, record)) } } @@ -354,23 +351,23 @@ func BenchmarkORMInsertAndDelete(b *testing.B) { insertAndDeleteWithStraightExec := func(b *testing.B) { ids := make([]int64, batchSize) for idx := range ids { - result := must.Return(db.Exec(`INSERT INTO entries (message) VALUES (?)`, "hello"))(b) + result := must.Return(db.Base.Exec(`INSERT INTO entries (message) VALUES (?)`, "hello"))(b) ids[idx] = must.Return(result.LastInsertId())(b) } for _, id := range ids { - _ = must.Return(db.Exec(`DELETE FROM entries WHERE id = ?`, id))(b) + _ = must.Return(db.Base.Exec(`DELETE FROM entries WHERE id = ?`, id))(b) } } insertAndDeleteWithPreparedExec := func(b *testing.B) { ids := make([]int64, batchSize) - stmtInsert := must.Return(db.Prepare(`INSERT INTO entries (message) VALUES (?)`))(b) + stmtInsert := must.Return(db.Base.Prepare(`INSERT INTO entries (message) VALUES (?)`))(b) defer stmtInsert.Close() for idx := range ids { result := must.Return(stmtInsert.Exec("hello"))(b) ids[idx] = must.Return(result.LastInsertId())(b) } - stmtDelete := must.Return(db.Prepare(`DELETE FROM entries WHERE id = ?`))(b) + stmtDelete := must.Return(db.Base.Prepare(`DELETE FROM entries WHERE id = ?`))(b) defer stmtDelete.Close() for _, id := range ids { _ = must.Return(stmtDelete.Exec(id))(b) @@ -380,21 +377,21 @@ func BenchmarkORMInsertAndDelete(b *testing.B) { insertAndDeleteWithStraightQueryRow := func(b *testing.B) { ids := make([]int64, batchSize) for idx := range ids { - must.Succeed(b, db.QueryRow(`INSERT INTO entries (message) VALUES (?) RETURNING id`, "hello").Scan(&ids[idx])) + must.Succeed(b, db.Base.QueryRow(`INSERT INTO entries (message) VALUES (?) RETURNING id`, "hello").Scan(&ids[idx])) } for _, id := range ids { - _ = must.Return(db.Exec(`DELETE FROM entries WHERE id = ?`, id))(b) + _ = must.Return(db.Base.Exec(`DELETE FROM entries WHERE id = ?`, id))(b) } } insertAndDeleteWithPreparedQueryRow := func(b *testing.B) { ids := make([]int64, batchSize) - stmtInsert := must.Return(db.Prepare(`INSERT INTO entries (message) VALUES (?) RETURNING id`))(b) + stmtInsert := must.Return(db.Base.Prepare(`INSERT INTO entries (message) VALUES (?) RETURNING id`))(b) defer stmtInsert.Close() for idx := range ids { must.Succeed(b, stmtInsert.QueryRow("hello").Scan(&ids[idx])) } - stmtDelete := must.Return(db.Prepare(`DELETE FROM entries WHERE id = ?`))(b) + stmtDelete := must.Return(db.Base.Prepare(`DELETE FROM entries WHERE id = ?`))(b) defer stmtDelete.Close() for _, id := range ids { _ = must.Return(stmtDelete.Exec(id))(b) @@ -447,14 +444,13 @@ func BenchmarkORMInsertAndDelete(b *testing.B) { func BenchmarkORMUpdate(b *testing.B) { db, dsn := makeSqliteTestDB(b, 0) - dbh := oblast.Wrap(db) store := oblast.MustNewStore[OblastEntry]( oblast.SqliteDialect(), oblast.TableNameIs("entries"), oblast.PrimaryKeyIs("id"), ) - gorpDB := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}} + gorpDB := gorp.DbMap{Db: db.Base, Dialect: gorp.SqliteDialect{}} gorpDB.AddTableWithName(GorpEntry{}, "entries").SetKeys(true, "id") gormDB := must.Return(gorm.Open(sqlite.Open(dsn), &gorm.Config{}))(b) @@ -462,14 +458,14 @@ func BenchmarkORMUpdate(b *testing.B) { for _, batchSize := range batchSizesForUpdate { b.Run("N="+strconv.Itoa(batchSize), func(b *testing.B) { // prepare a bunch of records that we can update, in a reproducible way - _ = must.Return(db.Exec(`DELETE FROM entries`)) + _ = must.Return(db.Base.Exec(`DELETE FROM entries`)) recordsForOblast := make([]OblastEntry, batchSize) recordsForOblastForInsert := make([]*OblastEntry, batchSize) for idx := range recordsForOblast { recordsForOblast[idx] = OblastEntry{Message: "hello"} recordsForOblastForInsert[idx] = &recordsForOblast[idx] } - must.Succeed(b, store.Insert(noctx, dbh, recordsForOblastForInsert...)) + must.Succeed(b, store.Insert(noctx, db, recordsForOblastForInsert...)) recordsForGorp := make([]any, batchSize) for idx, r := range recordsForOblast { recordsForGorp[idx] = new(GorpEntry(r)) @@ -484,7 +480,7 @@ func BenchmarkORMUpdate(b *testing.B) { for idx := range recordsForOblast { recordsForOblast[idx].Message = message } - must.Succeed(b, store.Update(noctx, dbh, recordsForOblast...)) + must.Succeed(b, store.Update(noctx, db, recordsForOblast...)) } updateWithGorp := func(b *testing.B, message string) { for _, r := range recordsForGorp { @@ -502,11 +498,11 @@ func BenchmarkORMUpdate(b *testing.B) { } updateWithStraightSqlite := func(b *testing.B, message string) { for _, r := range recordsForOblast { - _ = must.Return(db.Exec(`UPDATE entries SET message = ? WHERE id = ?`, message, r.ID))(b) + _ = must.Return(db.Base.Exec(`UPDATE entries SET message = ? WHERE id = ?`, message, r.ID))(b) } } updateWithPreparedSqlite := func(b *testing.B, message string) { - stmt := must.Return(db.Prepare(`UPDATE entries SET message = ? WHERE id = ?`))(b) + stmt := must.Return(db.Base.Prepare(`UPDATE entries SET message = ? WHERE id = ?`))(b) for _, r := range recordsForOblast { _ = must.Return(stmt.Exec(message, r.ID))(b) } @@ -514,7 +510,7 @@ func BenchmarkORMUpdate(b *testing.B) { } checkRecordsUpdated := func(b *testing.B, message string) { var count int64 - must.Succeed(b, db.QueryRow(`SELECT COUNT(*) FROM entries WHERE message = ?`, message).Scan(&count)) + must.Succeed(b, db.Base.QueryRow(`SELECT COUNT(*) FROM entries WHERE message = ?`, message).Scan(&count)) assert.Equal(b, count, int64(batchSize)) } diff --git a/benchmark/internal/oblast_pgx/handle.go b/benchmark/internal/oblast_pgx/handle.go index 6a88e2b..7ccc9ea 100644 --- a/benchmark/internal/oblast_pgx/handle.go +++ b/benchmark/internal/oblast_pgx/handle.go @@ -43,8 +43,8 @@ type wrappedHandle struct { inner Handle } -// Prepare implements the [handle.Handle] interface. -func (h wrappedHandle) Prepare(ctx context.Context, query string, repeated bool) (handle.Statement, error) { +// OblastPrepare implements the [handle.Handle] interface. +func (h wrappedHandle) OblastPrepare(ctx context.Context, query string, repeated bool) (handle.Statement, error) { if !repeated { return wrappedUnpreparedStatement{query, h.inner}, nil } @@ -74,8 +74,8 @@ func deallocate(ctx context.Context, h Handle, stmt *pgconn.StatementDescription } } -// Query implements the [handle.Handle] interface. -func (h wrappedHandle) Query(ctx context.Context, query string, args []any) (handle.Rows, error) { +// OblastQuery implements the [handle.Handle] interface. +func (h wrappedHandle) OblastQuery(ctx context.Context, query string, args []any) (handle.Rows, error) { rows, err := h.inner.Query(ctx, query, args...) return wrappedRows{rows}, err } diff --git a/benchmark/postgres_test.go b/benchmark/postgres_test.go index 320ea2a..02c2c43 100644 --- a/benchmark/postgres_test.go +++ b/benchmark/postgres_test.go @@ -21,7 +21,7 @@ import ( "go.xyrillian.de/oblast/internal/testhelpers/must" ) -// NOTE: In this file, we benchmark different PostgreSQL database drivers against each other with or without Oblast inbetween. +// NOTE: In this file, we benchmark different PostgreSQL database drivers against each other with or without Oblast in between. // All benchmarks are called "BenchmarkPostgres...". // To run these benchmarks, you need to have provide a DSN to a PostgreSQL database in $BENCHMARK_POSTGRES_DSN. @@ -36,14 +36,14 @@ func BenchmarkPostgresHeadingHeadingHeadingHeadingHeadingHeadingHeadingHeading(b const defaultPostgresDSN = "host=localhost user=postgres dbname=oblast_benchmark sslmode=disable" -func connectToPostgresTestDB(t testing.TB, recordCount int) *sql.DB { +func connectToPostgresTestDB(t testing.TB, recordCount int) oblast.SqlHandle[*sql.DB] { dsn := cmp.Or(os.Getenv("BENCHMARK_POSTGRES_DSN"), defaultPostgresDSN) - db := must.Return(sql.Open("postgres", dsn))(t) - _ = must.Return(db.Exec(`CREATE TEMPORARY TABLE entries (id BIGSERIAL, message TEXT)`))(t) + db := oblast.Wrap(must.Return(sql.Open("postgres", dsn))(t)) + _ = must.Return(db.Base.Exec(`CREATE TEMPORARY TABLE entries (id BIGSERIAL, message TEXT)`))(t) if recordCount > 0 { // fill in some random-looking, but deterministic data - stmt := must.Return(db.Prepare(`INSERT INTO entries (id, message) VALUES ($1, $2)`))(t) + stmt := must.Return(db.Base.Prepare(`INSERT INTO entries (id, message) VALUES ($1, $2)`))(t) for idx := range recordCount { buf := sha256.Sum256([]byte(strconv.Itoa(idx))) _ = must.Return(stmt.Exec(idx, fmt.Sprintf("sha256:%x", buf[:])))(t) @@ -62,11 +62,11 @@ func connectToPgxTestDB(t testing.TB, recordCount int) *pgx.Conn { if recordCount > 0 { // fill in some random-looking, but deterministic data - sql := `INSERT INTO entries (id, message) VALUES ($1, $2)` - stmt := must.Return(conn.Prepare(ctx, sql, sql))(t) + query := `INSERT INTO entries (id, message) VALUES ($1, $2)` + stmt := must.Return(conn.Prepare(ctx, query, query))(t) for idx := range recordCount { buf := sha256.Sum256([]byte(strconv.Itoa(idx))) - _ = must.Return(conn.Exec(ctx, sql, idx, fmt.Sprintf("sha256:%x", buf[:])))(t) + _ = must.Return(conn.Exec(ctx, query, idx, fmt.Sprintf("sha256:%x", buf[:])))(t) } must.Succeed(t, conn.Deallocate(ctx, stmt.Name)) } @@ -76,7 +76,6 @@ func connectToPgxTestDB(t testing.TB, recordCount int) *pgx.Conn { func BenchmarkPostgresSelect(b *testing.B) { pqDB := connectToPostgresTestDB(b, totalRecordCountForSelect) - pqDBH := oblast.Wrap(pqDB) pgxConn := connectToPgxTestDB(b, totalRecordCountForSelect) pgxConnH := oblast_pgx.Wrap(pgxConn) @@ -93,7 +92,7 @@ func BenchmarkPostgresSelect(b *testing.B) { b.Run("driver=pq/strategy=oblast", func(b *testing.B) { for b.Loop() { - records := must.Return(store.Select(noctx, pqDBH, query))(b) + records := must.Return(store.Select(noctx, pqDB, query))(b) assert.Equal(b, len(records), batchSize) } }) @@ -108,7 +107,7 @@ func BenchmarkPostgresSelect(b *testing.B) { b.Run("driver=pq/strategy=straight", func(b *testing.B) { for b.Loop() { var records []OblastEntry - rows := must.Return(pqDB.Query(query))(b) //nolint:rowserrcheck // false positive + rows := must.Return(pqDB.Base.Query(query))(b) //nolint:rowserrcheck // false positive for rows.Next() { var e OblastEntry must.Succeed(b, rows.Scan(&e.ID, &e.Message)) @@ -122,7 +121,7 @@ func BenchmarkPostgresSelect(b *testing.B) { b.Run("driver=pgx/strategy=straight", func(b *testing.B) { for b.Loop() { var records []OblastEntry - rows := must.Return(pgxConn.Query(noctx, query))(b) //nolint:rowserrcheck // false positive + rows := must.Return(pgxConn.Query(noctx, query))(b) for rows.Next() { var e OblastEntry must.Succeed(b, rows.Scan(&e.ID, &e.Message)) @@ -138,7 +137,6 @@ func BenchmarkPostgresSelect(b *testing.B) { func BenchmarkPostgresSelectOne(b *testing.B) { pqDB := connectToPostgresTestDB(b, totalRecordCountForSelect) - pqDBH := oblast.Wrap(pqDB) pgxConn := connectToPgxTestDB(b, totalRecordCountForSelect) pgxConnH := oblast_pgx.Wrap(pgxConn) @@ -157,7 +155,7 @@ func BenchmarkPostgresSelectOne(b *testing.B) { b.Run("driver=pq/strategy=oblast", func(b *testing.B) { for b.Loop() { - r := must.Return(precomputedQuery.SelectOne(noctx, pqDBH))(b) + r := must.Return(precomputedQuery.SelectOne(noctx, pqDB))(b) assert.Equal(b, r.ID, recordID) } }) @@ -175,7 +173,7 @@ func BenchmarkPostgresSelectOne(b *testing.B) { id int64 message string ) - must.Succeed(b, pqDB.QueryRow(query).Scan(&id, &message)) + must.Succeed(b, pqDB.Base.QueryRow(query).Scan(&id, &message)) assert.Equal(b, id, int64(recordID)) } }) @@ -194,7 +192,6 @@ func BenchmarkPostgresSelectOne(b *testing.B) { func BenchmarkPostgresInsertAndDelete(b *testing.B) { pqDB := connectToPostgresTestDB(b, 0) - pqDBH := oblast.Wrap(pqDB) pgxConn := connectToPgxTestDB(b, 0) pgxConnH := oblast_pgx.Wrap(pgxConn) @@ -225,7 +222,7 @@ func BenchmarkPostgresInsertAndDelete(b *testing.B) { b.Run("driver=pq/strategy=oblast", func(b *testing.B) { for b.Loop() { - insertAndDeleteWithOblast(b, pqDBH) + insertAndDeleteWithOblast(b, pqDB) } }) @@ -242,10 +239,10 @@ func BenchmarkPostgresInsertAndDelete(b *testing.B) { for b.Loop() { ids := make([]int64, batchSize) for idx := range ids { - must.Succeed(b, pqDB.QueryRow(insertQuery, "hello").Scan(&ids[idx])) + must.Succeed(b, pqDB.Base.QueryRow(insertQuery, "hello").Scan(&ids[idx])) } for _, id := range ids { - _ = must.Return(pqDB.Exec(deleteQuery, id))(b) + _ = must.Return(pqDB.Base.Exec(deleteQuery, id))(b) } } }) @@ -265,12 +262,12 @@ func BenchmarkPostgresInsertAndDelete(b *testing.B) { b.Run("driver=pq/strategy=prepared", func(b *testing.B) { for b.Loop() { ids := make([]int64, batchSize) - stmtInsert := must.Return(pqDB.Prepare(insertQuery))(b) + stmtInsert := must.Return(pqDB.Base.Prepare(insertQuery))(b) defer stmtInsert.Close() for idx := range ids { must.Succeed(b, stmtInsert.QueryRow("hello").Scan(&ids[idx])) } - stmtDelete := must.Return(pqDB.Prepare(deleteQuery))(b) + stmtDelete := must.Return(pqDB.Base.Prepare(deleteQuery))(b) defer stmtDelete.Close() for _, id := range ids { _ = must.Return(stmtDelete.Exec(id))(b) @@ -299,7 +296,6 @@ func BenchmarkPostgresInsertAndDelete(b *testing.B) { func BenchmarkPostgresUpdate(b *testing.B) { pqDB := connectToPostgresTestDB(b, 0) - pqDBH := oblast.Wrap(pqDB) pgxConn := connectToPgxTestDB(b, 0) pgxConnH := oblast_pgx.Wrap(pgxConn) @@ -313,7 +309,7 @@ func BenchmarkPostgresUpdate(b *testing.B) { for _, batchSize := range batchSizesForInsertDelete { b.Run("N="+strconv.Itoa(batchSize), func(b *testing.B) { // prepare a bunch of records that we can update, in a reproducible way - _ = must.Return(pqDB.Exec(`DELETE FROM entries`)) + _ = must.Return(pqDB.Base.Exec(`DELETE FROM entries`)) _ = must.Return(pgxConn.Exec(noctx, `DELETE FROM entries`)) pqRecords := make([]OblastEntry, batchSize) pqRecordsForInsert := make([]*OblastEntry, batchSize) @@ -325,7 +321,7 @@ func BenchmarkPostgresUpdate(b *testing.B) { pgxRecords[idx] = OblastEntry{Message: "hello"} pgxRecordsForInsert[idx] = &pgxRecords[idx] } - must.Succeed(b, store.Insert(noctx, pqDBH, pqRecordsForInsert...)) + must.Succeed(b, store.Insert(noctx, pqDB, pqRecordsForInsert...)) must.Succeed(b, store.Insert(noctx, pgxConnH, pgxRecordsForInsert...)) // each benchmark will, while looping, write changing values each time in the same way @@ -348,7 +344,7 @@ func BenchmarkPostgresUpdate(b *testing.B) { } b.Run("driver=pq/strategy=oblast", func(b *testing.B) { - loop(b, updateWithOblast(b, pqDBH, pqRecords)) + loop(b, updateWithOblast(b, pqDB, pqRecords)) }) b.Run("driver=pgx/strategy=oblast", func(b *testing.B) { @@ -360,7 +356,7 @@ func BenchmarkPostgresUpdate(b *testing.B) { b.Run("driver=pq/strategy=straight", func(b *testing.B) { loop(b, func(message string) { for _, r := range pqRecords { - _ = must.Return(pqDB.Exec(updateQuery, message, r.ID))(b) + _ = must.Return(pqDB.Base.Exec(updateQuery, message, r.ID))(b) } }) }) @@ -375,7 +371,7 @@ func BenchmarkPostgresUpdate(b *testing.B) { b.Run("driver=pq/strategy=prepared", func(b *testing.B) { loop(b, func(message string) { - stmt := must.Return(pqDB.Prepare(updateQuery))(b) + stmt := must.Return(pqDB.Base.Prepare(updateQuery))(b) for _, r := range pqRecords { _ = must.Return(stmt.Exec(message, r.ID))(b) } diff --git a/handle.go b/handle.go index b7f8608..bb251bb 100644 --- a/handle.go +++ b/handle.go @@ -16,9 +16,25 @@ import ( // Custom implementations of this interface can be used to connect non-std database drivers to Oblast. type Handle = handle.Handle -// StdHandle is an interface covered by both [*sql.DB] and [*sql.Tx]. +// SqlHandle wraps types like [*sql.DB] or [*sql.Tx] into a [Handle] that can be used with Oblast. +type SqlHandle[T SqlExecutor] struct { + // The original database or transaction handle. + // It is safe to read this field to execute operations that Oblast does not handle (e.g. transactions, savepoints or OLAP queries). + Base T + + // If this is not true, then any methods on this type will panic. + // This is just to enforce that the handle is constructed with Wrap(), thus guaranteeing future compatibility if actual important private struct fields are added later. + ok bool +} + +// Wrap converts an [*sql.DB] or [*sql.Tx] into a [Handle] that can be used with Oblast functions. +func Wrap[T SqlExecutor](dbOrTx T) SqlHandle[T] { + return SqlHandle[T]{Base: dbOrTx, ok: true} +} + +// SqlExecutor is an interface covered by both [*sql.DB] and [*sql.Tx]. // It appears in the signature of function [Wrap]. -type StdHandle interface { +type SqlExecutor interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) @@ -27,38 +43,35 @@ type StdHandle interface { // static assertion that the respective types implement the interface var ( - _ StdHandle = &sql.DB{} - _ StdHandle = &sql.Tx{} + _ SqlExecutor = &sql.DB{} + _ SqlExecutor = &sql.Tx{} ) -// Wrap converts an [*sql.DB] or [*sql.Tx] into a [Handle] that can be used with Oblast functions. -func Wrap(dbOrTx StdHandle) Handle { - return wrappedHandle{dbOrTx} -} - -type wrappedHandle struct { - db StdHandle -} - -// Prepare implements the [Handle] interface. -func (h wrappedHandle) Prepare(ctx context.Context, query string, repeated bool) (handle.Statement, error) { +// OblastPrepare implements the [Handle] interface. +func (h SqlHandle[T]) OblastPrepare(ctx context.Context, query string, repeated bool) (handle.Statement, error) { + if !h.ok { + panic("SqlHandle was not constructed through oblast.Wrap()!") + } if !repeated { - return wrappedStatement{h.db, query, nil}, nil + return wrappedStatement{h.Base, query, nil}, nil } - stmt, err := h.db.PrepareContext(ctx, query) + stmt, err := h.Base.PrepareContext(ctx, query) if err != nil { return nil, fmt.Errorf("during Prepare(): %w", err) } - return wrappedStatement{h.db, query, stmt}, nil + return wrappedStatement{h.Base, query, stmt}, nil } -// Query implements the [Handle] interface. -func (h wrappedHandle) Query(ctx context.Context, query string, args []any) (handle.Rows, error) { - return h.db.QueryContext(ctx, query, args...) //nolint:rowserrcheck // the caller does the check +// OblastQuery implements the [Handle] interface. +func (h SqlHandle[T]) OblastQuery(ctx context.Context, query string, args []any) (handle.Rows, error) { + if !h.ok { + panic("SqlHandle was not constructed through oblast.Wrap()!") + } + return h.Base.QueryContext(ctx, query, args...) //nolint:rowserrcheck // the caller does the check } type wrappedStatement struct { - db StdHandle + db SqlExecutor query string stmt *sql.Stmt // nil if repeated = false } diff --git a/handle/handle.go b/handle/handle.go index eaf3558..4af712a 100644 --- a/handle/handle.go +++ b/handle/handle.go @@ -13,16 +13,18 @@ import ( // Handle contains behavior that database handles must offer to Oblast. // The standard-library types [*sql.DB] and [*sql.Tx] can satisfy this interface through the Wrap() function from the main package. // Custom implementations of this interface can be used to connect non-std database drivers to Oblast. +// +// The method names are deliberately clunky to avoid name clashes with well-known methods like [sql.DB.Prepare] or [sql.DB.Query]. type Handle interface { - // Prepare prepares to execute a certain SQL query one or multiple times. + // OblastPrepare prepares to execute a certain SQL query one or multiple times. // // The "repeated" flag is a hint to the implementation whether the same statement is going to be run many times. // If false, the implementation shall choose to forego the additional effort of a full statement preparation if possible, // and execute one-off queries instead. - Prepare(ctx context.Context, query string, repeated bool) (Statement, error) + OblastPrepare(ctx context.Context, query string, repeated bool) (Statement, error) - // Query works like db.QueryContext(ctx, query, args...). - Query(ctx context.Context, query string, args []any) (Rows, error) + // OblastQuery works like db.QueryContext(ctx, query, args...). + OblastQuery(ctx context.Context, query string, args []any) (Rows, error) } // Statement represents a prepared statement returned from [Handle.Prepare]. diff --git a/query.go b/query.go index eea1771..79b7abd 100644 --- a/query.go +++ b/query.go @@ -31,7 +31,7 @@ func prepare(ctx context.Context, db Handle, query, operation string, inputSize return nil, fmt.Errorf("cannot execute %s() because query could not be autogenerated", operation) } - return db.Prepare(ctx, query, inputSize >= PrepareThreshold) + return db.OblastPrepare(ctx, query, inputSize >= PrepareThreshold) } // Insert executes an SQL INSERT statement for each of the provided records. diff --git a/select.go b/select.go index 17195d0..35c0671 100644 --- a/select.go +++ b/select.go @@ -105,7 +105,7 @@ func (s Store[R]) SelectWhere(ctx context.Context, db Handle, partialQuery strin } func startSelectQuery(ctx context.Context, db Handle, plan plan, query string, args ...any) (handle.Rows, [][]int, error) { - rows, err := db.Query(ctx, query, args) + rows, err := db.OblastQuery(ctx, query, args) if err != nil { return nil, nil, fmt.Errorf("during Query(): %w", err) } @@ -136,7 +136,7 @@ func startSelectWhereQuery(ctx context.Context, db Handle, plan plan, partialQue return nil, nil, errors.New("cannot execute SelectWhere() because query could not be autogenerated") } query := plan.Select.Query + partialQuery - rows, err = db.Query(ctx, query, args) + rows, err = db.OblastQuery(ctx, query, args) if err != nil { err = fmt.Errorf("during Query(): %w", err) } @@ -240,7 +240,7 @@ func selectOne(ctx context.Context, db Handle, plan plan, v reflect.Value, query for idx, index := range plan.Select.ScanIndexes { slots[idx] = v.FieldByIndex(index).Addr().Interface() } - stmt, err := db.Prepare(ctx, query, false) + stmt, err := db.OblastPrepare(ctx, query, false) if err != nil { return err } -- cgit v1.2.3