diff options
| -rw-r--r-- | CHANGELOG.md | 11 | ||||
| -rw-r--r-- | benchmark/.gitignore | 3 | ||||
| -rw-r--r-- | benchmark/benchmark_test.go | 24 | ||||
| -rw-r--r-- | handle.go | 90 | ||||
| -rw-r--r-- | handle/handle.go | 52 | ||||
| -rw-r--r-- | oblast.go | 20 | ||||
| -rw-r--r-- | query.go | 65 | ||||
| -rw-r--r-- | query_test.go | 20 | ||||
| -rw-r--r-- | runtimeindex_test.go | 2 | ||||
| -rw-r--r-- | select.go | 18 | ||||
| -rw-r--r-- | select_test.go | 16 |
11 files changed, 216 insertions, 105 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index e16f018..d1f54cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,17 @@ SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net> SPDX-License-Identifier: Apache-2.0 --> +# v0.7.0 (TBD) + +API changes: + +- The `Handle` type changes to not cover `*sql.DB` and `*sql.Tx` directly, thus removing the direct dependency on database/sql. + This adds a new memory allocation (for wrapping `*sql.DB` or `*sql.Tx` in the wrapper implementing `Handle`) + and some CPU overhead because of the interface indirection, but I consider this a worthwhile tradeoff + to enable the use of non-standard database drivers like <https://github.com/jackc/pgx> + (if the user provides the respective custom implementation of the `Handle` interface). + Preliminary benchmarking has already shown that, for the PostgreSQL case, oblast + jackc/pgx is significantly more efficient than oblast + lib/pq. + # v0.6.0 (2026-05-08) API changes: diff --git a/benchmark/.gitignore b/benchmark/.gitignore new file mode 100644 index 0000000..ed6d513 --- /dev/null +++ b/benchmark/.gitignore @@ -0,0 +1,3 @@ +# artifacts from running e.g. `go test -bench . -benchmem -memprofile mem.out` +/benchmark.test +/*.out diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index 721d3a2..f0967b9 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -77,6 +77,7 @@ func (GormEntry) TableName() string { return "entries" } func BenchmarkSelectMany(b *testing.B) { db, dsn := makeTestDB(b, totalRecordCountForSelect) + dbh := oblast.Wrap(db) // test with different sizes of resultsets (N=1 is an OLTP-like workload, // then the larger N lean more towards the OLAP side of things) @@ -95,12 +96,12 @@ func BenchmarkSelectMany(b *testing.B) { precomputedQuery := store.MustPrepareSelectQueryWhere(partialQuery) selectWithOblast := func(b *testing.B) { - records := must.Return(store.Select(noctx, db, query))(b) + records := must.Return(store.Select(noctx, dbh, query))(b) assert.Equal(b, len(records), batchSize) } selectWithOblastWhere := func(b *testing.B) { - records := must.Return(precomputedQuery.Select(noctx, db))(b) + records := must.Return(precomputedQuery.Select(noctx, dbh))(b) assert.Equal(b, len(records), batchSize) } @@ -172,6 +173,7 @@ func BenchmarkSelectMany(b *testing.B) { func BenchmarkSelectOne(b *testing.B) { db, dsn := makeTestDB(b, totalRecordCountForSelect) + dbh := oblast.Wrap(db) // grab a "random" record from the DB, not just the first or the last recordID := min(totalRecordCountForSelect*2/3, totalRecordCountForSelect) @@ -189,12 +191,12 @@ func BenchmarkSelectOne(b *testing.B) { precomputedQuery := store.MustPrepareSelectQueryWhere(partialQuery) selectWithOblast := func(b *testing.B) { - r := must.Return(store.SelectOne(noctx, db, query))(b) + r := must.Return(store.SelectOne(noctx, dbh, query))(b) assert.Equal(b, r.ID, recordID) } selectWithOblastWhere := func(b *testing.B) { - r := must.Return(precomputedQuery.SelectOne(noctx, db))(b) + r := must.Return(precomputedQuery.SelectOne(noctx, dbh))(b) assert.Equal(b, r.ID, recordID) } @@ -256,6 +258,7 @@ func BenchmarkSelectOne(b *testing.B) { func BenchmarkInsertAndDelete(b *testing.B) { db, dsn := makeTestDB(b, 0) + dbh := oblast.Wrap(db) store := oblast.MustNewStore[OblastEntry]( oblast.SqliteDialect(), @@ -277,22 +280,22 @@ func BenchmarkInsertAndDelete(b *testing.B) { records[idx] = OblastEntry{Message: "hello"} recordsForInsert[idx] = &records[idx] } - must.Succeed(b, store.Insert(noctx, db, recordsForInsert...)) + must.Succeed(b, store.Insert(noctx, dbh, recordsForInsert...)) for _, r := range records { if r.ID == 0 { b.Errorf("ID was not filled!") } } - must.Succeed(b, store.Delete(noctx, db, records...)) + must.Succeed(b, store.Delete(noctx, dbh, records...)) } if batchSize == 1 { insertAndDeleteWithOblast = func(b *testing.B) { record := OblastEntry{Message: "hello"} - must.Succeed(b, store.Insert(noctx, db, &record)) + must.Succeed(b, store.Insert(noctx, dbh, &record)) if record.ID == 0 { b.Errorf("ID was not filled!") } - must.Succeed(b, store.Delete(noctx, db, record)) + must.Succeed(b, store.Delete(noctx, dbh, record)) } } @@ -441,6 +444,7 @@ func BenchmarkInsertAndDelete(b *testing.B) { func BenchmarkUpdate(b *testing.B) { db, dsn := makeTestDB(b, 0) + dbh := oblast.Wrap(db) store := oblast.MustNewStore[OblastEntry]( oblast.SqliteDialect(), @@ -462,7 +466,7 @@ func BenchmarkUpdate(b *testing.B) { recordsForOblast[idx] = OblastEntry{Message: "hello"} recordsForOblastForInsert[idx] = &recordsForOblast[idx] } - must.Succeed(b, store.Insert(noctx, db, recordsForOblastForInsert...)) + must.Succeed(b, store.Insert(noctx, dbh, recordsForOblastForInsert...)) recordsForGorp := make([]any, batchSize) for idx, r := range recordsForOblast { recordsForGorp[idx] = new(GorpEntry(r)) @@ -477,7 +481,7 @@ func BenchmarkUpdate(b *testing.B) { for idx := range recordsForOblast { recordsForOblast[idx].Message = message } - must.Succeed(b, store.Update(noctx, db, recordsForOblast...)) + must.Succeed(b, store.Update(noctx, dbh, recordsForOblast...)) } updateWithGorp := func(b *testing.B, message string) { for _, r := range recordsForGorp { diff --git a/handle.go b/handle.go new file mode 100644 index 0000000..b7f8608 --- /dev/null +++ b/handle.go @@ -0,0 +1,90 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net> +// SPDX-License-Identifier: Apache-2.0 + +package oblast + +import ( + "context" + "database/sql" + "fmt" + + "go.xyrillian.de/oblast/handle" +) + +// Handle contains behavior that database handles must offer to Oblast. +// The standard-library types [*sql.DB] and [*sql.Tx] can satisfy this interface through the [Wrap] function. +// Custom implementations of this interface can be used to connect non-std database drivers to Oblast. +type Handle = handle.Handle + +// StdHandle is an interface covered by both [*sql.DB] and [*sql.Tx]. +// It appears in the signature of function [Wrap]. +type StdHandle interface { + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row +} + +// static assertion that the respective types implement the interface +var ( + _ StdHandle = &sql.DB{} + _ StdHandle = &sql.Tx{} +) + +// Wrap converts an [*sql.DB] or [*sql.Tx] into a [Handle] that can be used with Oblast functions. +func Wrap(dbOrTx StdHandle) Handle { + return wrappedHandle{dbOrTx} +} + +type wrappedHandle struct { + db StdHandle +} + +// Prepare implements the [Handle] interface. +func (h wrappedHandle) Prepare(ctx context.Context, query string, repeated bool) (handle.Statement, error) { + if !repeated { + return wrappedStatement{h.db, query, nil}, nil + } + stmt, err := h.db.PrepareContext(ctx, query) + if err != nil { + return nil, fmt.Errorf("during Prepare(): %w", err) + } + return wrappedStatement{h.db, query, stmt}, nil +} + +// Query implements the [Handle] interface. +func (h wrappedHandle) Query(ctx context.Context, query string, args []any) (handle.Rows, error) { + return h.db.QueryContext(ctx, query, args...) //nolint:rowserrcheck // the caller does the check +} + +type wrappedStatement struct { + db StdHandle + query string + stmt *sql.Stmt // nil if repeated = false +} + +// Close implements the [Statement] interface. +func (s wrappedStatement) Close() error { + if s.stmt == nil { + return nil + } + return s.stmt.Close() +} + +// Exec implements the [Statement] interface. +func (s wrappedStatement) Exec(ctx context.Context, args []any) (sql.Result, error) { + if s.stmt == nil { + return s.db.ExecContext(ctx, s.query, args...) + } else { + return s.stmt.ExecContext(ctx, args...) + } +} + +// QueryRow implements the [Statement] interface. +func (s wrappedStatement) QueryRow(ctx context.Context, args, slots []any) error { + if s.stmt == nil { + return s.db.QueryRowContext(ctx, s.query, args...).Scan(slots...) + } else { + return s.stmt.QueryRowContext(ctx, args...).Scan(slots...) + } +} diff --git a/handle/handle.go b/handle/handle.go new file mode 100644 index 0000000..41d82b5 --- /dev/null +++ b/handle/handle.go @@ -0,0 +1,52 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net> +// SPDX-License-Identifier: Apache-2.0 + +// Package handle contains type definitions for connecting non-std database drivers to Oblast. +// Since most database drivers use the standard interface from databse/sql, the Wrap() function from the main package covers the needs of most users. +package handle + +import ( + "context" + "database/sql" +) + +// Handle contains behavior that database handles must offer to Oblast. +// The standard-library types [*sql.DB] and [*sql.Tx] can satisfy this interface through the Wrap() function from the main package. +// Custom implementations of this interface can be used to connect non-std database drivers to Oblast. +type Handle interface { + // Prepare prepares to execute a certain SQL query one or multiple times. + // + // The "repeated" flag is a hint to the implementation whether the same statement is going to be run many times. + // If false, the implementation shall choose to forego the additional effort of a full statement preparation if possible, + // and execute one-off queries instead. + Prepare(ctx context.Context, query string, repeated bool) (Statement, error) + + // Query works like db.QueryContext(ctx, query, args...). + Query(ctx context.Context, query string, args []any) (Rows, error) +} + +// Statement represents a prepared statement returned from [Handle.Prepare]. +// The Exec and QueryRow methods shall work similarly to the respective functions on [*sql.Tx], as indicated in the comments. +// +// You will not need to interact with this type except when implementing your own [Handle]. +type Statement interface { + Close() error + + // Exec works like stmt.ExecContext(ctx, args...). + Exec(ctx context.Context, args []any) (sql.Result, error) + + // QueryRow works like stmt.QueryRow(ctx, args...).Scan(slots...). + QueryRow(ctx context.Context, args []any, slots []any) error +} + +// Rows represents a set of rows returned from [Handle.Query] in response to a DB query. +// All methods shall behave like on the [*sql.Rows] type from std. +// +// You will not need to interact with this type except when implementing your own [Handle]. +type Rows interface { + Columns() ([]string, error) + Close() error + Err() error + Next() bool + Scan(slots ...any) error +} @@ -96,7 +96,6 @@ package oblast // import "go.xyrillian.de/oblast" import ( - "context" "database/sql" "database/sql/driver" "fmt" @@ -132,25 +131,6 @@ func StructTagKeyIs(key string) PlanOption { return func(opts *planOpts) { opts.StructTagKey = key } } -// Handle is an interface for functions providing direct DB access. -// It covers methods provided by both *sql.DB and *sql.Tx, thus allowing functions using it to be used both within and outside of transactions. -type Handle interface { - ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) - PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) - QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) - QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row -} - -// TODO: investigate if we can extend type Handle to cover types github.com/jackc/pgx.{Conn,Tx} -// - those have all these methods, but with different return types that act mostly in the same way -// - a significant departure is that their Prepare() works wildly differently - -// static assertion that the respective types implement the interface -var ( - _ Handle = &sql.DB{} - _ Handle = &sql.Tx{} -) - // Store is the main interface of this library. // // It holds information on how to read and write data into record type R, @@ -5,9 +5,10 @@ package oblast import ( "context" - "database/sql" "fmt" "reflect" + + "go.xyrillian.de/oblast/handle" ) // PrepareThreshold is a tuning parameter for the strategy used by all methods of [Store] operating on batches of records provided by the caller @@ -20,53 +21,17 @@ import ( // This tuning parameter defines the minimum number of records that will justify maintaining a prepared statement. // Our benchmarking with the mattn/go-sqlite3 driver (and last checked with Go 1.26.2 on x86_64) indicates that this becomes a worthwhile investment at 8 or more records, so this is our default. // If your benchmarking indicates a different tradeoff depending on your choice of Go version or SQL driver, you may adjust this variable accordingly. +// +// The actual effect of this setting is to control the value of the "repeated" argument in [Handle.Prepare]. var PrepareThreshold int = 8 -// preparedStatement behaves like sql.Stmt, but only uses *sql.Stmt when it is useful (see explanation above). -type preparedStatement struct { - db Handle - query string - stmt *sql.Stmt // nil for input sizes below PrepareThreshold -} - // prepare behaves like [Handle.Prepare]. -func prepare(ctx context.Context, db Handle, query, operation string, inputSize int) (preparedStatement, error) { +func prepare(ctx context.Context, db Handle, query, operation string, inputSize int) (handle.Statement, error) { if query == "" { - return preparedStatement{}, fmt.Errorf("cannot execute %s() because query could not be autogenerated", operation) - } - - if inputSize < PrepareThreshold { - return preparedStatement{db, query, nil}, nil - } - stmt, err := db.PrepareContext(ctx, query) - if err != nil { - return preparedStatement{}, fmt.Errorf("during Prepare(): %w", err) + return nil, fmt.Errorf("cannot execute %s() because query could not be autogenerated", operation) } - return preparedStatement{db, query, stmt}, nil -} - -// Close behaves like [sql.Stmt.Close]. -func (s preparedStatement) Close() error { - if s.stmt == nil { - return nil - } - return s.stmt.Close() -} -// ExecContext behaves like [sql.Stmt.ExecContext]. -func (s preparedStatement) ExecContext(ctx context.Context, args ...any) (sql.Result, error) { - if s.stmt == nil { - return s.db.ExecContext(ctx, s.query, args...) - } - return s.stmt.ExecContext(ctx, args...) -} - -// QueryRow behaves like [sql.Stmt.QueryRowContext]. -func (s preparedStatement) QueryRowContext(ctx context.Context, args ...any) *sql.Row { - if s.stmt == nil { - return s.db.QueryRowContext(ctx, s.query, args...) - } - return s.stmt.QueryRowContext(ctx, args...) + return db.Prepare(ctx, query, inputSize >= PrepareThreshold) } // Insert executes an SQL INSERT statement for each of the provided records. @@ -89,7 +54,7 @@ func (s Store[R]) Insert(ctx context.Context, db Handle, records ...*R) error { return s.insertUsing(ctx, stmt, db, records) } -func (s Store[R]) insertUsing(ctx context.Context, stmt preparedStatement, db Handle, records []*R) error { +func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Handle, records []*R) error { // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. // Any expression that does not depend on type R should be factored out into a reusable function. @@ -111,7 +76,7 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt preparedStatement, db Ha return newIOError(nil, "Stmt.Close", stmt.Close()) } -func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt preparedStatement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error { +func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error { for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } @@ -124,10 +89,10 @@ func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt pr } var err error if len(scanSlots) == 0 { - _, err = stmt.ExecContext(ctx, argumentSlots...) + _, err = stmt.Exec(ctx, argumentSlots) } else { // TODO: using QueryRow for inserting is extremely expensive because database/sql allocates a Rows instance under the hood; other libraries are doing better by limiting themselves to ExecContext() + LastInsertId() - err = stmt.QueryRowContext(ctx, argumentSlots...).Scan(scanSlots...) + err = stmt.QueryRow(ctx, argumentSlots, scanSlots) } if err != nil { return fmt.Errorf("while inserting record with idx = %d: %w", recordIndex, err) @@ -166,11 +131,11 @@ func (s Store[R]) Update(ctx context.Context, db Handle, records ...R) error { return newIOError(nil, "Stmt.Close", stmt.Close()) } -func updateRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt preparedStatement, argumentIndexes [][]int, argumentSlots []any) (int64, error) { +func updateRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) (int64, error) { for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } - result, err := stmt.ExecContext(ctx, argumentSlots...) + result, err := stmt.Exec(ctx, argumentSlots) if err != nil { return 0, fmt.Errorf("while updating record with idx = %d: %w", recordIndex, err) } @@ -209,11 +174,11 @@ func (s Store[R]) Delete(ctx context.Context, db Handle, records ...R) error { return newIOError(nil, "Stmt.Close", stmt.Close()) } -func deleteRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt preparedStatement, argumentIndexes [][]int, argumentSlots []any) error { +func deleteRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) error { for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } - _, err := stmt.ExecContext(ctx, argumentSlots...) + _, err := stmt.Exec(ctx, argumentSlots) if err != nil { return fmt.Errorf("while deleting record with idx = %d: %w", recordIndex, err) } diff --git a/query_test.go b/query_test.go index 05a3af2..41b0203 100644 --- a/query_test.go +++ b/query_test.go @@ -18,7 +18,7 @@ import ( func TestInsertBasic(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `oblast:"id,auto"` @@ -52,7 +52,7 @@ func TestInsertBasic(t *testing.T) { func TestUpdateBasic(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id,auto"` @@ -82,7 +82,7 @@ func TestUpdateBasic(t *testing.T) { func TestDeleteBasic(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id,auto"` @@ -112,7 +112,7 @@ func TestDeleteBasic(t *testing.T) { func TestUpsertBasicWithAutoColumn(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id,auto"` @@ -158,7 +158,7 @@ func TestUpsertBasicWithAutoColumn(t *testing.T) { func TestWriteQueriesNotPossible(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id,auto"` @@ -187,7 +187,7 @@ func TestWriteQueriesNotPossible(t *testing.T) { func TestWriteQueriesFailDuringPrepare(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id,auto"` @@ -236,7 +236,7 @@ func TestWriteQueriesFailDuringPrepare(t *testing.T) { func TestUpdateOrUpsertFailsOnMissingRecord(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id,auto"` @@ -271,7 +271,7 @@ func TestUpdateOrUpsertFailsOnMissingRecord(t *testing.T) { func TestInsertFailsOnFilledAutoField(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id,auto"` @@ -294,7 +294,7 @@ func TestInsertFailsOnFilledAutoField(t *testing.T) { func TestInsertAndUpsertWithNoAutoColumns(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type relation struct { FooID int64 `db:"foo_id"` @@ -325,7 +325,7 @@ func TestInsertAndUpsertWithNoAutoColumns(t *testing.T) { func TestUpsertFailsOnMixedAutoFieldState(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type complexRecord struct { ID int64 `db:"id,auto"` diff --git a/runtimeindex_test.go b/runtimeindex_test.go index ba16fd9..8e0b68f 100644 --- a/runtimeindex_test.go +++ b/runtimeindex_test.go @@ -16,7 +16,7 @@ import ( func TestRuntimeIndex(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id"` @@ -11,6 +11,7 @@ import ( "reflect" . "go.xyrillian.de/gg/option" + "go.xyrillian.de/oblast/handle" ) // Select executes the provided SQL query and fills an instance of the record type R for each row in the result set, @@ -103,8 +104,8 @@ func (s Store[R]) SelectWhere(ctx context.Context, db Handle, partialQuery strin return result, newIOError(err, "Rows.Err", rows.Err()) } -func startSelectQuery(ctx context.Context, db Handle, plan plan, query string, args ...any) (*sql.Rows, [][]int, error) { - rows, err := db.QueryContext(ctx, query, args...) +func startSelectQuery(ctx context.Context, db Handle, plan plan, query string, args ...any) (handle.Rows, [][]int, error) { + rows, err := db.Query(ctx, query, args) if err != nil { return nil, nil, fmt.Errorf("during Query(): %w", err) } @@ -130,19 +131,19 @@ func startSelectQuery(ctx context.Context, db Handle, plan plan, query string, a return rows, indexes, nil } -func startSelectWhereQuery(ctx context.Context, db Handle, plan plan, partialQuery string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { +func startSelectWhereQuery(ctx context.Context, db Handle, plan plan, partialQuery string, args ...any) (rows handle.Rows, indexes [][]int, err error) { if plan.Select.Query == "" { return nil, nil, errors.New("cannot execute SelectWhere() because query could not be autogenerated") } query := plan.Select.Query + partialQuery - rows, err = db.QueryContext(ctx, query, args...) + rows, err = db.Query(ctx, query, args) if err != nil { err = fmt.Errorf("during Query(): %w", err) } return rows, plan.Select.ScanIndexes, err } -func collectRow(rows *sql.Rows, plan plan, v reflect.Value, slots []any, indexes [][]int) error { +func collectRow(rows handle.Rows, plan plan, v reflect.Value, slots []any, indexes [][]int) error { for _, index := range plan.IndexesOfTransparentPointerStructs { f := v.FieldByIndex(index) f.Set(reflect.New(f.Type().Elem())) @@ -239,7 +240,12 @@ func selectOne(ctx context.Context, db Handle, plan plan, v reflect.Value, query for idx, index := range plan.Select.ScanIndexes { slots[idx] = v.FieldByIndex(index).Addr().Interface() } - return db.QueryRowContext(ctx, query, args...).Scan(slots...) + stmt, err := db.Prepare(ctx, query, false) + if err != nil { + return err + } + err = stmt.QueryRow(ctx, args, slots) + return newIOError(err, "Stmt.Close", stmt.Close()) } func noRowsToNone[R any](record R, err error) (Option[R], error) { diff --git a/select_test.go b/select_test.go index bd56f7d..7b4191a 100644 --- a/select_test.go +++ b/select_test.go @@ -20,7 +20,7 @@ import ( func TestSelectReturningSomeRecords(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id"` @@ -138,7 +138,7 @@ func TestSelectReturningSomeRecords(t *testing.T) { func TestSelectReturningNoRecords(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id"` @@ -229,7 +229,7 @@ func TestSelectReturningNoRecords(t *testing.T) { func TestSelectIntoUnexpectedField(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id"` @@ -268,7 +268,7 @@ func TestSelectIntoUnexpectedField(t *testing.T) { func TestSelectWithScanError(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id"` @@ -331,7 +331,7 @@ func TestSelectWithScanError(t *testing.T) { func TestSelectIntoEmbeddedTypes(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type HasCreatedAt struct { CreatedAt time.Time `db:"created_at"` @@ -442,7 +442,7 @@ func TestSelectIntoEmbeddedTypes(t *testing.T) { func TestSelectCapturingQueryError(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id"` @@ -490,7 +490,7 @@ func TestSelectCapturingQueryError(t *testing.T) { func TestSelectCapturingCloseError(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id"` @@ -553,7 +553,7 @@ func TestSelectCapturingCloseError(t *testing.T) { func TestSelectNotPossibleWithoutTableName(t *testing.T) { ctx := t.Context() md := mock.NewDriver() - db := sql.OpenDB(md) + db := oblast.Wrap(sql.OpenDB(md)) type basicRecord struct { ID int64 `db:"id"` |
