aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Majewsky <majewsky@gmx.net>2026-05-13 01:11:30 +0200
committerStefan Majewsky <majewsky@gmx.net>2026-05-13 01:13:17 +0200
commita561ebb42148c72638f943e44191da07c16df7f6 (patch)
treefb2ecc409fa3c0d39ac8408da95820db8ebebed0
parent2fe6a5a42ccb663211f4f4804b78fff3bd9ebdc0 (diff)
downloadgo-oblast-a561ebb42148c72638f943e44191da07c16df7f6.tar.gz
return a concrete type from Wrap() to enable non-Oblast DB operations
-rw-r--r--CHANGELOG.md6
-rw-r--r--benchmark/benchmark_test.go68
-rw-r--r--benchmark/internal/oblast_pgx/handle.go8
-rw-r--r--benchmark/postgres_test.go50
-rw-r--r--handle.go57
-rw-r--r--handle/handle.go10
-rw-r--r--query.go2
-rw-r--r--select.go6
8 files changed, 110 insertions, 97 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0f0aa71..a260a2d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,12 @@ SPDX-License-Identifier: Apache-2.0
# v0.8.0 (TBD)
+API changes:
+
+- `Wrap` now returns a struct type `SqlHandle` instead of the interface type `Handle`.
+ This enables reaching into the `SqlHandle` and getting the original `*sql.DB` and `*sql.Tx` back out,
+ which is more ergonomic in situations where Oblast loads/stores need to be mixed with other types of DB operations.
+
Changes:
- Insert, Upsert, Update and Delete will no longer panic when one of the fields they need to access is within a pointer-to-struct that is nil.
diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go
index da764b2..4bc7950 100644
--- a/benchmark/benchmark_test.go
+++ b/benchmark/benchmark_test.go
@@ -43,14 +43,14 @@ var (
batchSizesForUpdate = []int{1, 2, 4, 8, 16, 100}
)
-func makeSqliteTestDB(t testing.TB, recordCount int) (db *sql.DB, dsn string) {
+func makeSqliteTestDB(t testing.TB, recordCount int) (db oblast.SqlHandle[*sql.DB], dsn string) {
dsn = fmt.Sprintf("file:%s?mode=memory&cache=shared", t.Name())
- db = must.Return(sql.Open("sqlite3", dsn))(t)
- _ = must.Return(db.Exec(`CREATE TABLE entries (id INTEGER, message TEXT, PRIMARY KEY (id AUTOINCREMENT))`))(t)
+ db = oblast.Wrap(must.Return(sql.Open("sqlite3", dsn))(t))
+ _ = must.Return(db.Base.Exec(`CREATE TABLE entries (id INTEGER, message TEXT, PRIMARY KEY (id AUTOINCREMENT))`))(t)
if recordCount > 0 {
// fill in some random-looking, but deterministic data
- stmt := must.Return(db.Prepare(`INSERT INTO entries (id, message) VALUES (?, ?)`))(t)
+ stmt := must.Return(db.Base.Prepare(`INSERT INTO entries (id, message) VALUES (?, ?)`))(t)
for idx := range recordCount {
buf := sha256.Sum256([]byte(strconv.Itoa(idx)))
_ = must.Return(stmt.Exec(idx, fmt.Sprintf("sha256:%x", buf[:])))(t)
@@ -80,7 +80,6 @@ func (GormEntry) TableName() string { return "entries" }
func BenchmarkORMSelectMany(b *testing.B) {
db, dsn := makeSqliteTestDB(b, totalRecordCountForSelect)
- dbh := oblast.Wrap(db)
// test with different sizes of resultsets (N=1 is an OLTP-like workload,
// then the larger N lean more towards the OLAP side of things)
@@ -92,19 +91,19 @@ func BenchmarkORMSelectMany(b *testing.B) {
oblast.TableNameIs("entries"),
oblast.PrimaryKeyIs("id"),
)
- gorpDB := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}}
+ gorpDB := gorp.DbMap{Db: db.Base, Dialect: gorp.SqliteDialect{}}
gormDB := must.Return(gorm.Open(sqlite.Open(dsn), &gorm.Config{}))(b)
partialQuery := `id < ` + strconv.Itoa(batchSize)
query := `SELECT * FROM entries WHERE ` + partialQuery
precomputedQuery := store.MustPrepareSelectQueryWhere(partialQuery)
selectWithOblast := func(b *testing.B) {
- records := must.Return(store.Select(noctx, dbh, query))(b)
+ records := must.Return(store.Select(noctx, db, query))(b)
assert.Equal(b, len(records), batchSize)
}
selectWithOblastWhere := func(b *testing.B) {
- records := must.Return(precomputedQuery.Select(noctx, dbh))(b)
+ records := must.Return(precomputedQuery.Select(noctx, db))(b)
assert.Equal(b, len(records), batchSize)
}
@@ -121,7 +120,7 @@ func BenchmarkORMSelectMany(b *testing.B) {
selectWithSqlite := func(b *testing.B) {
var count int
- rows := must.Return(db.Query(query))(b) //nolint:rowserrcheck // false positive
+ rows := must.Return(db.Base.Query(query))(b) //nolint:rowserrcheck // false positive
var (
id int64
message string
@@ -176,7 +175,6 @@ func BenchmarkORMSelectMany(b *testing.B) {
func BenchmarkORMSelectOne(b *testing.B) {
db, dsn := makeSqliteTestDB(b, totalRecordCountForSelect)
- dbh := oblast.Wrap(db)
// grab a "random" record from the DB, not just the first or the last
recordID := min(totalRecordCountForSelect*2/3, totalRecordCountForSelect)
@@ -187,19 +185,19 @@ func BenchmarkORMSelectOne(b *testing.B) {
oblast.TableNameIs("entries"),
oblast.PrimaryKeyIs("id"),
)
- gorpDB := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}}
+ gorpDB := gorp.DbMap{Db: db.Base, Dialect: gorp.SqliteDialect{}}
gormDB := must.Return(gorm.Open(sqlite.Open(dsn), &gorm.Config{}))(b)
partialQuery := `id = ` + strconv.Itoa(recordID)
query := `SELECT * FROM entries WHERE ` + partialQuery
precomputedQuery := store.MustPrepareSelectQueryWhere(partialQuery)
selectWithOblast := func(b *testing.B) {
- r := must.Return(store.SelectOne(noctx, dbh, query))(b)
+ r := must.Return(store.SelectOne(noctx, db, query))(b)
assert.Equal(b, r.ID, recordID)
}
selectWithOblastWhere := func(b *testing.B) {
- r := must.Return(precomputedQuery.SelectOne(noctx, dbh))(b)
+ r := must.Return(precomputedQuery.SelectOne(noctx, db))(b)
assert.Equal(b, r.ID, recordID)
}
@@ -219,7 +217,7 @@ func BenchmarkORMSelectOne(b *testing.B) {
id int64
message string
)
- must.Succeed(b, db.QueryRow(query).Scan(&id, &message))
+ must.Succeed(b, db.Base.QueryRow(query).Scan(&id, &message))
assert.Equal(b, id, int64(recordID))
}
@@ -261,14 +259,13 @@ func BenchmarkORMSelectOne(b *testing.B) {
func BenchmarkORMInsertAndDelete(b *testing.B) {
db, dsn := makeSqliteTestDB(b, 0)
- dbh := oblast.Wrap(db)
store := oblast.MustNewStore[OblastEntry](
oblast.SqliteDialect(),
oblast.TableNameIs("entries"),
oblast.PrimaryKeyIs("id"),
)
- gorpDB := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}}
+ gorpDB := gorp.DbMap{Db: db.Base, Dialect: gorp.SqliteDialect{}}
gorpDB.AddTableWithName(GorpEntry{}, "entries").SetKeys(true, "id")
gormDB := must.Return(gorm.Open(sqlite.Open(dsn), &gorm.Config{}))(b)
@@ -283,22 +280,22 @@ func BenchmarkORMInsertAndDelete(b *testing.B) {
records[idx] = OblastEntry{Message: "hello"}
recordsForInsert[idx] = &records[idx]
}
- must.Succeed(b, store.Insert(noctx, dbh, recordsForInsert...))
+ must.Succeed(b, store.Insert(noctx, db, recordsForInsert...))
for _, r := range records {
if r.ID == 0 {
b.Errorf("ID was not filled!")
}
}
- must.Succeed(b, store.Delete(noctx, dbh, records...))
+ must.Succeed(b, store.Delete(noctx, db, records...))
}
if batchSize == 1 {
insertAndDeleteWithOblast = func(b *testing.B) {
record := OblastEntry{Message: "hello"}
- must.Succeed(b, store.Insert(noctx, dbh, &record))
+ must.Succeed(b, store.Insert(noctx, db, &record))
if record.ID == 0 {
b.Errorf("ID was not filled!")
}
- must.Succeed(b, store.Delete(noctx, dbh, record))
+ must.Succeed(b, store.Delete(noctx, db, record))
}
}
@@ -354,23 +351,23 @@ func BenchmarkORMInsertAndDelete(b *testing.B) {
insertAndDeleteWithStraightExec := func(b *testing.B) {
ids := make([]int64, batchSize)
for idx := range ids {
- result := must.Return(db.Exec(`INSERT INTO entries (message) VALUES (?)`, "hello"))(b)
+ result := must.Return(db.Base.Exec(`INSERT INTO entries (message) VALUES (?)`, "hello"))(b)
ids[idx] = must.Return(result.LastInsertId())(b)
}
for _, id := range ids {
- _ = must.Return(db.Exec(`DELETE FROM entries WHERE id = ?`, id))(b)
+ _ = must.Return(db.Base.Exec(`DELETE FROM entries WHERE id = ?`, id))(b)
}
}
insertAndDeleteWithPreparedExec := func(b *testing.B) {
ids := make([]int64, batchSize)
- stmtInsert := must.Return(db.Prepare(`INSERT INTO entries (message) VALUES (?)`))(b)
+ stmtInsert := must.Return(db.Base.Prepare(`INSERT INTO entries (message) VALUES (?)`))(b)
defer stmtInsert.Close()
for idx := range ids {
result := must.Return(stmtInsert.Exec("hello"))(b)
ids[idx] = must.Return(result.LastInsertId())(b)
}
- stmtDelete := must.Return(db.Prepare(`DELETE FROM entries WHERE id = ?`))(b)
+ stmtDelete := must.Return(db.Base.Prepare(`DELETE FROM entries WHERE id = ?`))(b)
defer stmtDelete.Close()
for _, id := range ids {
_ = must.Return(stmtDelete.Exec(id))(b)
@@ -380,21 +377,21 @@ func BenchmarkORMInsertAndDelete(b *testing.B) {
insertAndDeleteWithStraightQueryRow := func(b *testing.B) {
ids := make([]int64, batchSize)
for idx := range ids {
- must.Succeed(b, db.QueryRow(`INSERT INTO entries (message) VALUES (?) RETURNING id`, "hello").Scan(&ids[idx]))
+ must.Succeed(b, db.Base.QueryRow(`INSERT INTO entries (message) VALUES (?) RETURNING id`, "hello").Scan(&ids[idx]))
}
for _, id := range ids {
- _ = must.Return(db.Exec(`DELETE FROM entries WHERE id = ?`, id))(b)
+ _ = must.Return(db.Base.Exec(`DELETE FROM entries WHERE id = ?`, id))(b)
}
}
insertAndDeleteWithPreparedQueryRow := func(b *testing.B) {
ids := make([]int64, batchSize)
- stmtInsert := must.Return(db.Prepare(`INSERT INTO entries (message) VALUES (?) RETURNING id`))(b)
+ stmtInsert := must.Return(db.Base.Prepare(`INSERT INTO entries (message) VALUES (?) RETURNING id`))(b)
defer stmtInsert.Close()
for idx := range ids {
must.Succeed(b, stmtInsert.QueryRow("hello").Scan(&ids[idx]))
}
- stmtDelete := must.Return(db.Prepare(`DELETE FROM entries WHERE id = ?`))(b)
+ stmtDelete := must.Return(db.Base.Prepare(`DELETE FROM entries WHERE id = ?`))(b)
defer stmtDelete.Close()
for _, id := range ids {
_ = must.Return(stmtDelete.Exec(id))(b)
@@ -447,14 +444,13 @@ func BenchmarkORMInsertAndDelete(b *testing.B) {
func BenchmarkORMUpdate(b *testing.B) {
db, dsn := makeSqliteTestDB(b, 0)
- dbh := oblast.Wrap(db)
store := oblast.MustNewStore[OblastEntry](
oblast.SqliteDialect(),
oblast.TableNameIs("entries"),
oblast.PrimaryKeyIs("id"),
)
- gorpDB := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}}
+ gorpDB := gorp.DbMap{Db: db.Base, Dialect: gorp.SqliteDialect{}}
gorpDB.AddTableWithName(GorpEntry{}, "entries").SetKeys(true, "id")
gormDB := must.Return(gorm.Open(sqlite.Open(dsn), &gorm.Config{}))(b)
@@ -462,14 +458,14 @@ func BenchmarkORMUpdate(b *testing.B) {
for _, batchSize := range batchSizesForUpdate {
b.Run("N="+strconv.Itoa(batchSize), func(b *testing.B) {
// prepare a bunch of records that we can update, in a reproducible way
- _ = must.Return(db.Exec(`DELETE FROM entries`))
+ _ = must.Return(db.Base.Exec(`DELETE FROM entries`))
recordsForOblast := make([]OblastEntry, batchSize)
recordsForOblastForInsert := make([]*OblastEntry, batchSize)
for idx := range recordsForOblast {
recordsForOblast[idx] = OblastEntry{Message: "hello"}
recordsForOblastForInsert[idx] = &recordsForOblast[idx]
}
- must.Succeed(b, store.Insert(noctx, dbh, recordsForOblastForInsert...))
+ must.Succeed(b, store.Insert(noctx, db, recordsForOblastForInsert...))
recordsForGorp := make([]any, batchSize)
for idx, r := range recordsForOblast {
recordsForGorp[idx] = new(GorpEntry(r))
@@ -484,7 +480,7 @@ func BenchmarkORMUpdate(b *testing.B) {
for idx := range recordsForOblast {
recordsForOblast[idx].Message = message
}
- must.Succeed(b, store.Update(noctx, dbh, recordsForOblast...))
+ must.Succeed(b, store.Update(noctx, db, recordsForOblast...))
}
updateWithGorp := func(b *testing.B, message string) {
for _, r := range recordsForGorp {
@@ -502,11 +498,11 @@ func BenchmarkORMUpdate(b *testing.B) {
}
updateWithStraightSqlite := func(b *testing.B, message string) {
for _, r := range recordsForOblast {
- _ = must.Return(db.Exec(`UPDATE entries SET message = ? WHERE id = ?`, message, r.ID))(b)
+ _ = must.Return(db.Base.Exec(`UPDATE entries SET message = ? WHERE id = ?`, message, r.ID))(b)
}
}
updateWithPreparedSqlite := func(b *testing.B, message string) {
- stmt := must.Return(db.Prepare(`UPDATE entries SET message = ? WHERE id = ?`))(b)
+ stmt := must.Return(db.Base.Prepare(`UPDATE entries SET message = ? WHERE id = ?`))(b)
for _, r := range recordsForOblast {
_ = must.Return(stmt.Exec(message, r.ID))(b)
}
@@ -514,7 +510,7 @@ func BenchmarkORMUpdate(b *testing.B) {
}
checkRecordsUpdated := func(b *testing.B, message string) {
var count int64
- must.Succeed(b, db.QueryRow(`SELECT COUNT(*) FROM entries WHERE message = ?`, message).Scan(&count))
+ must.Succeed(b, db.Base.QueryRow(`SELECT COUNT(*) FROM entries WHERE message = ?`, message).Scan(&count))
assert.Equal(b, count, int64(batchSize))
}
diff --git a/benchmark/internal/oblast_pgx/handle.go b/benchmark/internal/oblast_pgx/handle.go
index 6a88e2b..7ccc9ea 100644
--- a/benchmark/internal/oblast_pgx/handle.go
+++ b/benchmark/internal/oblast_pgx/handle.go
@@ -43,8 +43,8 @@ type wrappedHandle struct {
inner Handle
}
-// Prepare implements the [handle.Handle] interface.
-func (h wrappedHandle) Prepare(ctx context.Context, query string, repeated bool) (handle.Statement, error) {
+// OblastPrepare implements the [handle.Handle] interface.
+func (h wrappedHandle) OblastPrepare(ctx context.Context, query string, repeated bool) (handle.Statement, error) {
if !repeated {
return wrappedUnpreparedStatement{query, h.inner}, nil
}
@@ -74,8 +74,8 @@ func deallocate(ctx context.Context, h Handle, stmt *pgconn.StatementDescription
}
}
-// Query implements the [handle.Handle] interface.
-func (h wrappedHandle) Query(ctx context.Context, query string, args []any) (handle.Rows, error) {
+// OblastQuery implements the [handle.Handle] interface.
+func (h wrappedHandle) OblastQuery(ctx context.Context, query string, args []any) (handle.Rows, error) {
rows, err := h.inner.Query(ctx, query, args...)
return wrappedRows{rows}, err
}
diff --git a/benchmark/postgres_test.go b/benchmark/postgres_test.go
index 320ea2a..02c2c43 100644
--- a/benchmark/postgres_test.go
+++ b/benchmark/postgres_test.go
@@ -21,7 +21,7 @@ import (
"go.xyrillian.de/oblast/internal/testhelpers/must"
)
-// NOTE: In this file, we benchmark different PostgreSQL database drivers against each other with or without Oblast inbetween.
+// NOTE: In this file, we benchmark different PostgreSQL database drivers against each other with or without Oblast in between.
// All benchmarks are called "BenchmarkPostgres...".
// To run these benchmarks, you need to have provide a DSN to a PostgreSQL database in $BENCHMARK_POSTGRES_DSN.
@@ -36,14 +36,14 @@ func BenchmarkPostgresHeadingHeadingHeadingHeadingHeadingHeadingHeadingHeading(b
const defaultPostgresDSN = "host=localhost user=postgres dbname=oblast_benchmark sslmode=disable"
-func connectToPostgresTestDB(t testing.TB, recordCount int) *sql.DB {
+func connectToPostgresTestDB(t testing.TB, recordCount int) oblast.SqlHandle[*sql.DB] {
dsn := cmp.Or(os.Getenv("BENCHMARK_POSTGRES_DSN"), defaultPostgresDSN)
- db := must.Return(sql.Open("postgres", dsn))(t)
- _ = must.Return(db.Exec(`CREATE TEMPORARY TABLE entries (id BIGSERIAL, message TEXT)`))(t)
+ db := oblast.Wrap(must.Return(sql.Open("postgres", dsn))(t))
+ _ = must.Return(db.Base.Exec(`CREATE TEMPORARY TABLE entries (id BIGSERIAL, message TEXT)`))(t)
if recordCount > 0 {
// fill in some random-looking, but deterministic data
- stmt := must.Return(db.Prepare(`INSERT INTO entries (id, message) VALUES ($1, $2)`))(t)
+ stmt := must.Return(db.Base.Prepare(`INSERT INTO entries (id, message) VALUES ($1, $2)`))(t)
for idx := range recordCount {
buf := sha256.Sum256([]byte(strconv.Itoa(idx)))
_ = must.Return(stmt.Exec(idx, fmt.Sprintf("sha256:%x", buf[:])))(t)
@@ -62,11 +62,11 @@ func connectToPgxTestDB(t testing.TB, recordCount int) *pgx.Conn {
if recordCount > 0 {
// fill in some random-looking, but deterministic data
- sql := `INSERT INTO entries (id, message) VALUES ($1, $2)`
- stmt := must.Return(conn.Prepare(ctx, sql, sql))(t)
+ query := `INSERT INTO entries (id, message) VALUES ($1, $2)`
+ stmt := must.Return(conn.Prepare(ctx, query, query))(t)
for idx := range recordCount {
buf := sha256.Sum256([]byte(strconv.Itoa(idx)))
- _ = must.Return(conn.Exec(ctx, sql, idx, fmt.Sprintf("sha256:%x", buf[:])))(t)
+ _ = must.Return(conn.Exec(ctx, query, idx, fmt.Sprintf("sha256:%x", buf[:])))(t)
}
must.Succeed(t, conn.Deallocate(ctx, stmt.Name))
}
@@ -76,7 +76,6 @@ func connectToPgxTestDB(t testing.TB, recordCount int) *pgx.Conn {
func BenchmarkPostgresSelect(b *testing.B) {
pqDB := connectToPostgresTestDB(b, totalRecordCountForSelect)
- pqDBH := oblast.Wrap(pqDB)
pgxConn := connectToPgxTestDB(b, totalRecordCountForSelect)
pgxConnH := oblast_pgx.Wrap(pgxConn)
@@ -93,7 +92,7 @@ func BenchmarkPostgresSelect(b *testing.B) {
b.Run("driver=pq/strategy=oblast", func(b *testing.B) {
for b.Loop() {
- records := must.Return(store.Select(noctx, pqDBH, query))(b)
+ records := must.Return(store.Select(noctx, pqDB, query))(b)
assert.Equal(b, len(records), batchSize)
}
})
@@ -108,7 +107,7 @@ func BenchmarkPostgresSelect(b *testing.B) {
b.Run("driver=pq/strategy=straight", func(b *testing.B) {
for b.Loop() {
var records []OblastEntry
- rows := must.Return(pqDB.Query(query))(b) //nolint:rowserrcheck // false positive
+ rows := must.Return(pqDB.Base.Query(query))(b) //nolint:rowserrcheck // false positive
for rows.Next() {
var e OblastEntry
must.Succeed(b, rows.Scan(&e.ID, &e.Message))
@@ -122,7 +121,7 @@ func BenchmarkPostgresSelect(b *testing.B) {
b.Run("driver=pgx/strategy=straight", func(b *testing.B) {
for b.Loop() {
var records []OblastEntry
- rows := must.Return(pgxConn.Query(noctx, query))(b) //nolint:rowserrcheck // false positive
+ rows := must.Return(pgxConn.Query(noctx, query))(b)
for rows.Next() {
var e OblastEntry
must.Succeed(b, rows.Scan(&e.ID, &e.Message))
@@ -138,7 +137,6 @@ func BenchmarkPostgresSelect(b *testing.B) {
func BenchmarkPostgresSelectOne(b *testing.B) {
pqDB := connectToPostgresTestDB(b, totalRecordCountForSelect)
- pqDBH := oblast.Wrap(pqDB)
pgxConn := connectToPgxTestDB(b, totalRecordCountForSelect)
pgxConnH := oblast_pgx.Wrap(pgxConn)
@@ -157,7 +155,7 @@ func BenchmarkPostgresSelectOne(b *testing.B) {
b.Run("driver=pq/strategy=oblast", func(b *testing.B) {
for b.Loop() {
- r := must.Return(precomputedQuery.SelectOne(noctx, pqDBH))(b)
+ r := must.Return(precomputedQuery.SelectOne(noctx, pqDB))(b)
assert.Equal(b, r.ID, recordID)
}
})
@@ -175,7 +173,7 @@ func BenchmarkPostgresSelectOne(b *testing.B) {
id int64
message string
)
- must.Succeed(b, pqDB.QueryRow(query).Scan(&id, &message))
+ must.Succeed(b, pqDB.Base.QueryRow(query).Scan(&id, &message))
assert.Equal(b, id, int64(recordID))
}
})
@@ -194,7 +192,6 @@ func BenchmarkPostgresSelectOne(b *testing.B) {
func BenchmarkPostgresInsertAndDelete(b *testing.B) {
pqDB := connectToPostgresTestDB(b, 0)
- pqDBH := oblast.Wrap(pqDB)
pgxConn := connectToPgxTestDB(b, 0)
pgxConnH := oblast_pgx.Wrap(pgxConn)
@@ -225,7 +222,7 @@ func BenchmarkPostgresInsertAndDelete(b *testing.B) {
b.Run("driver=pq/strategy=oblast", func(b *testing.B) {
for b.Loop() {
- insertAndDeleteWithOblast(b, pqDBH)
+ insertAndDeleteWithOblast(b, pqDB)
}
})
@@ -242,10 +239,10 @@ func BenchmarkPostgresInsertAndDelete(b *testing.B) {
for b.Loop() {
ids := make([]int64, batchSize)
for idx := range ids {
- must.Succeed(b, pqDB.QueryRow(insertQuery, "hello").Scan(&ids[idx]))
+ must.Succeed(b, pqDB.Base.QueryRow(insertQuery, "hello").Scan(&ids[idx]))
}
for _, id := range ids {
- _ = must.Return(pqDB.Exec(deleteQuery, id))(b)
+ _ = must.Return(pqDB.Base.Exec(deleteQuery, id))(b)
}
}
})
@@ -265,12 +262,12 @@ func BenchmarkPostgresInsertAndDelete(b *testing.B) {
b.Run("driver=pq/strategy=prepared", func(b *testing.B) {
for b.Loop() {
ids := make([]int64, batchSize)
- stmtInsert := must.Return(pqDB.Prepare(insertQuery))(b)
+ stmtInsert := must.Return(pqDB.Base.Prepare(insertQuery))(b)
defer stmtInsert.Close()
for idx := range ids {
must.Succeed(b, stmtInsert.QueryRow("hello").Scan(&ids[idx]))
}
- stmtDelete := must.Return(pqDB.Prepare(deleteQuery))(b)
+ stmtDelete := must.Return(pqDB.Base.Prepare(deleteQuery))(b)
defer stmtDelete.Close()
for _, id := range ids {
_ = must.Return(stmtDelete.Exec(id))(b)
@@ -299,7 +296,6 @@ func BenchmarkPostgresInsertAndDelete(b *testing.B) {
func BenchmarkPostgresUpdate(b *testing.B) {
pqDB := connectToPostgresTestDB(b, 0)
- pqDBH := oblast.Wrap(pqDB)
pgxConn := connectToPgxTestDB(b, 0)
pgxConnH := oblast_pgx.Wrap(pgxConn)
@@ -313,7 +309,7 @@ func BenchmarkPostgresUpdate(b *testing.B) {
for _, batchSize := range batchSizesForInsertDelete {
b.Run("N="+strconv.Itoa(batchSize), func(b *testing.B) {
// prepare a bunch of records that we can update, in a reproducible way
- _ = must.Return(pqDB.Exec(`DELETE FROM entries`))
+ _ = must.Return(pqDB.Base.Exec(`DELETE FROM entries`))
_ = must.Return(pgxConn.Exec(noctx, `DELETE FROM entries`))
pqRecords := make([]OblastEntry, batchSize)
pqRecordsForInsert := make([]*OblastEntry, batchSize)
@@ -325,7 +321,7 @@ func BenchmarkPostgresUpdate(b *testing.B) {
pgxRecords[idx] = OblastEntry{Message: "hello"}
pgxRecordsForInsert[idx] = &pgxRecords[idx]
}
- must.Succeed(b, store.Insert(noctx, pqDBH, pqRecordsForInsert...))
+ must.Succeed(b, store.Insert(noctx, pqDB, pqRecordsForInsert...))
must.Succeed(b, store.Insert(noctx, pgxConnH, pgxRecordsForInsert...))
// each benchmark will, while looping, write changing values each time in the same way
@@ -348,7 +344,7 @@ func BenchmarkPostgresUpdate(b *testing.B) {
}
b.Run("driver=pq/strategy=oblast", func(b *testing.B) {
- loop(b, updateWithOblast(b, pqDBH, pqRecords))
+ loop(b, updateWithOblast(b, pqDB, pqRecords))
})
b.Run("driver=pgx/strategy=oblast", func(b *testing.B) {
@@ -360,7 +356,7 @@ func BenchmarkPostgresUpdate(b *testing.B) {
b.Run("driver=pq/strategy=straight", func(b *testing.B) {
loop(b, func(message string) {
for _, r := range pqRecords {
- _ = must.Return(pqDB.Exec(updateQuery, message, r.ID))(b)
+ _ = must.Return(pqDB.Base.Exec(updateQuery, message, r.ID))(b)
}
})
})
@@ -375,7 +371,7 @@ func BenchmarkPostgresUpdate(b *testing.B) {
b.Run("driver=pq/strategy=prepared", func(b *testing.B) {
loop(b, func(message string) {
- stmt := must.Return(pqDB.Prepare(updateQuery))(b)
+ stmt := must.Return(pqDB.Base.Prepare(updateQuery))(b)
for _, r := range pqRecords {
_ = must.Return(stmt.Exec(message, r.ID))(b)
}
diff --git a/handle.go b/handle.go
index b7f8608..bb251bb 100644
--- a/handle.go
+++ b/handle.go
@@ -16,9 +16,25 @@ import (
// Custom implementations of this interface can be used to connect non-std database drivers to Oblast.
type Handle = handle.Handle
-// StdHandle is an interface covered by both [*sql.DB] and [*sql.Tx].
+// SqlHandle wraps types like [*sql.DB] or [*sql.Tx] into a [Handle] that can be used with Oblast.
+type SqlHandle[T SqlExecutor] struct {
+ // The original database or transaction handle.
+ // It is safe to read this field to execute operations that Oblast does not handle (e.g. transactions, savepoints or OLAP queries).
+ Base T
+
+ // If this is not true, then any methods on this type will panic.
+ // This is just to enforce that the handle is constructed with Wrap(), thus guaranteeing future compatibility if actual important private struct fields are added later.
+ ok bool
+}
+
+// Wrap converts an [*sql.DB] or [*sql.Tx] into a [Handle] that can be used with Oblast functions.
+func Wrap[T SqlExecutor](dbOrTx T) SqlHandle[T] {
+ return SqlHandle[T]{Base: dbOrTx, ok: true}
+}
+
+// SqlExecutor is an interface covered by both [*sql.DB] and [*sql.Tx].
// It appears in the signature of function [Wrap].
-type StdHandle interface {
+type SqlExecutor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
@@ -27,38 +43,35 @@ type StdHandle interface {
// static assertion that the respective types implement the interface
var (
- _ StdHandle = &sql.DB{}
- _ StdHandle = &sql.Tx{}
+ _ SqlExecutor = &sql.DB{}
+ _ SqlExecutor = &sql.Tx{}
)
-// Wrap converts an [*sql.DB] or [*sql.Tx] into a [Handle] that can be used with Oblast functions.
-func Wrap(dbOrTx StdHandle) Handle {
- return wrappedHandle{dbOrTx}
-}
-
-type wrappedHandle struct {
- db StdHandle
-}
-
-// Prepare implements the [Handle] interface.
-func (h wrappedHandle) Prepare(ctx context.Context, query string, repeated bool) (handle.Statement, error) {
+// OblastPrepare implements the [Handle] interface.
+func (h SqlHandle[T]) OblastPrepare(ctx context.Context, query string, repeated bool) (handle.Statement, error) {
+ if !h.ok {
+ panic("SqlHandle was not constructed through oblast.Wrap()!")
+ }
if !repeated {
- return wrappedStatement{h.db, query, nil}, nil
+ return wrappedStatement{h.Base, query, nil}, nil
}
- stmt, err := h.db.PrepareContext(ctx, query)
+ stmt, err := h.Base.PrepareContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("during Prepare(): %w", err)
}
- return wrappedStatement{h.db, query, stmt}, nil
+ return wrappedStatement{h.Base, query, stmt}, nil
}
-// Query implements the [Handle] interface.
-func (h wrappedHandle) Query(ctx context.Context, query string, args []any) (handle.Rows, error) {
- return h.db.QueryContext(ctx, query, args...) //nolint:rowserrcheck // the caller does the check
+// OblastQuery implements the [Handle] interface.
+func (h SqlHandle[T]) OblastQuery(ctx context.Context, query string, args []any) (handle.Rows, error) {
+ if !h.ok {
+ panic("SqlHandle was not constructed through oblast.Wrap()!")
+ }
+ return h.Base.QueryContext(ctx, query, args...) //nolint:rowserrcheck // the caller does the check
}
type wrappedStatement struct {
- db StdHandle
+ db SqlExecutor
query string
stmt *sql.Stmt // nil if repeated = false
}
diff --git a/handle/handle.go b/handle/handle.go
index eaf3558..4af712a 100644
--- a/handle/handle.go
+++ b/handle/handle.go
@@ -13,16 +13,18 @@ import (
// Handle contains behavior that database handles must offer to Oblast.
// The standard-library types [*sql.DB] and [*sql.Tx] can satisfy this interface through the Wrap() function from the main package.
// Custom implementations of this interface can be used to connect non-std database drivers to Oblast.
+//
+// The method names are deliberately clunky to avoid name clashes with well-known methods like [sql.DB.Prepare] or [sql.DB.Query].
type Handle interface {
- // Prepare prepares to execute a certain SQL query one or multiple times.
+ // OblastPrepare prepares to execute a certain SQL query one or multiple times.
//
// The "repeated" flag is a hint to the implementation whether the same statement is going to be run many times.
// If false, the implementation shall choose to forego the additional effort of a full statement preparation if possible,
// and execute one-off queries instead.
- Prepare(ctx context.Context, query string, repeated bool) (Statement, error)
+ OblastPrepare(ctx context.Context, query string, repeated bool) (Statement, error)
- // Query works like db.QueryContext(ctx, query, args...).
- Query(ctx context.Context, query string, args []any) (Rows, error)
+ // OblastQuery works like db.QueryContext(ctx, query, args...).
+ OblastQuery(ctx context.Context, query string, args []any) (Rows, error)
}
// Statement represents a prepared statement returned from [Handle.Prepare].
diff --git a/query.go b/query.go
index eea1771..79b7abd 100644
--- a/query.go
+++ b/query.go
@@ -31,7 +31,7 @@ func prepare(ctx context.Context, db Handle, query, operation string, inputSize
return nil, fmt.Errorf("cannot execute %s() because query could not be autogenerated", operation)
}
- return db.Prepare(ctx, query, inputSize >= PrepareThreshold)
+ return db.OblastPrepare(ctx, query, inputSize >= PrepareThreshold)
}
// Insert executes an SQL INSERT statement for each of the provided records.
diff --git a/select.go b/select.go
index 17195d0..35c0671 100644
--- a/select.go
+++ b/select.go
@@ -105,7 +105,7 @@ func (s Store[R]) SelectWhere(ctx context.Context, db Handle, partialQuery strin
}
func startSelectQuery(ctx context.Context, db Handle, plan plan, query string, args ...any) (handle.Rows, [][]int, error) {
- rows, err := db.Query(ctx, query, args)
+ rows, err := db.OblastQuery(ctx, query, args)
if err != nil {
return nil, nil, fmt.Errorf("during Query(): %w", err)
}
@@ -136,7 +136,7 @@ func startSelectWhereQuery(ctx context.Context, db Handle, plan plan, partialQue
return nil, nil, errors.New("cannot execute SelectWhere() because query could not be autogenerated")
}
query := plan.Select.Query + partialQuery
- rows, err = db.Query(ctx, query, args)
+ rows, err = db.OblastQuery(ctx, query, args)
if err != nil {
err = fmt.Errorf("during Query(): %w", err)
}
@@ -240,7 +240,7 @@ func selectOne(ctx context.Context, db Handle, plan plan, v reflect.Value, query
for idx, index := range plan.Select.ScanIndexes {
slots[idx] = v.FieldByIndex(index).Addr().Interface()
}
- stmt, err := db.Prepare(ctx, query, false)
+ stmt, err := db.OblastPrepare(ctx, query, false)
if err != nil {
return err
}