aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Majewsky <majewsky@gmx.net>2026-04-14 00:41:25 +0200
committerStefan Majewsky <majewsky@gmx.net>2026-04-14 00:41:25 +0200
commit9191e018ff90deb99f3881966a5d356a05027e0f (patch)
treec36880ed2c0755132306141e61c8073d2926b0de
parent49a52b73afac2c97a8f3b7cffd434b29e6f30fcf (diff)
downloadgo-oblast-9191e018ff90deb99f3881966a5d356a05027e0f.tar.gz
initial test coverage for Store.Select functions
-rw-r--r--internal/assert/assert.go13
-rw-r--r--internal/mock/driver.go20
-rw-r--r--internal/must/must.go2
-rw-r--r--internal/plan.go2
-rw-r--r--oblast.go4
-rw-r--r--select.go31
-rw-r--r--select_test.go222
7 files changed, 263 insertions, 31 deletions
diff --git a/internal/assert/assert.go b/internal/assert/assert.go
index c82d35c..84b6ecf 100644
--- a/internal/assert/assert.go
+++ b/internal/assert/assert.go
@@ -23,3 +23,16 @@ func DeepEqual[V any](t testing.TB, actual, expected V) {
t.Errorf("expected %#v, but got %#v", expected, actual)
}
}
+
+// SliceEqual is a test assertion.
+func SliceEqual[V comparable](t testing.TB, actual []V, expected ...V) {
+ t.Helper()
+ if len(actual) != len(expected) {
+ t.Errorf("length mismatch: expected %#v, but got %#v", expected, actual)
+ }
+ for idx := range actual {
+ if actual[idx] != expected[idx] {
+ t.Errorf("element %d: expected %#v, but got %#v", idx, expected[idx], actual[idx])
+ }
+ }
+}
diff --git a/internal/mock/driver.go b/internal/mock/driver.go
index 4183097..d3358c4 100644
--- a/internal/mock/driver.go
+++ b/internal/mock/driver.go
@@ -75,22 +75,26 @@ func newExpectation[T any](args []any) expectation[T] {
output: new(T),
}
for idx, arg := range args {
- e.args[idx] = arg
+ var err error
+ e.args[idx], err = driver.DefaultParameterConverter.ConvertValue(arg)
+ if err != nil {
+ panic(fmt.Sprintf("could not convert value %#v into driver.Value: %s", arg, err.Error()))
+ }
}
return e
}
-// ExpectExec plans a response to an Exec() call.
-func (rs *ResponseSet) ExpectExec(args ...any) *Result {
+// ExpectExecWithArgs plans a response to an Exec() call.
+func (rs *ResponseSet) ExpectExecWithArgs(args ...any) *Result {
e := newExpectation[Result](args)
rs.expectedExecs = append(rs.expectedExecs, e)
return e.output
}
-// ExpectQuery plans a response to a Query() or QueryRows() call.
-func (rs *ResponseSet) ExpectQuery(args ...any) *Result {
- e := newExpectation[Result](args)
- rs.expectedExecs = append(rs.expectedExecs, e)
+// ExpectQueryWithArgs plans a response to a Query() or QueryRows() call.
+func (rs *ResponseSet) ExpectQueryWithArgs(args ...any) *Rows {
+ e := newExpectation[Rows](args)
+ rs.expectedQueries = append(rs.expectedQueries, e)
return e.output
}
@@ -258,7 +262,7 @@ func (r *Rows) AndReturnColumns(columns ...string) *Rows {
// WithRow adds a row to the result set that will be returned by this query.
// This may only be called after AndReturnColumns().
func (r *Rows) WithRow(values ...any) *Rows {
- if len(r.columns) != 0 {
+ if len(r.columns) == 0 {
panic("AndReturnColumns() has not been called for this Rows object yet")
}
if len(r.columns) != len(values) {
diff --git a/internal/must/must.go b/internal/must/must.go
index e472579..7a137c6 100644
--- a/internal/must/must.go
+++ b/internal/must/must.go
@@ -7,6 +7,7 @@ import "testing"
// Succeed fails the test if err is not nil.
func Succeed(t testing.TB, err error) {
+ t.Helper()
if err != nil {
t.Fatal(err.Error())
}
@@ -16,6 +17,7 @@ func Succeed(t testing.TB, err error) {
// and either forwards the result value on success, or fails the test on error.
func Return[V any](value V, err error) func(testing.TB) V {
return func(t testing.TB) V {
+ t.Helper()
if err != nil {
t.Fatal(err.Error())
}
diff --git a/internal/plan.go b/internal/plan.go
index f619a5f..b57b8dd 100644
--- a/internal/plan.go
+++ b/internal/plan.go
@@ -14,6 +14,7 @@ import (
// Plan holds all information that we can derive from reflecting on a given type.
// The queries held within are only valid within the context of a given SQL dialect.
type Plan struct {
+ TypeName string // for use in error messages
TableName string // from info.TableNameIs marker (if any)
AllColumnNames []string // in order of struct fields
PrimaryKeyColumnNames []string // from info.PrimaryKeyIs marker (if any)
@@ -64,6 +65,7 @@ func buildPlan(t reflect.Type, dialect Dialect, opts PlanOpts) (Plan, error) {
}
var p = Plan{
+ TypeName: t.Name(),
TableName: opts.TableName,
PrimaryKeyColumnNames: opts.PrimaryKeyColumnNames,
IndexByColumnName: make(map[string][]int),
diff --git a/oblast.go b/oblast.go
index 5a46042..15f840a 100644
--- a/oblast.go
+++ b/oblast.go
@@ -42,7 +42,6 @@ package oblast // import "go.xyrillian.de/oblast"
import (
"database/sql"
- "errors"
"reflect"
"go.xyrillian.de/oblast/internal"
@@ -78,9 +77,6 @@ var (
_ Handle = &sql.Tx{}
)
-// ErrMultipleRows is returned by [Store.SelectOne] if the query returned multiple rows.
-var ErrMultipleRows = errors.New("sql: multiple rows in result set")
-
// Store is the main interface of this library.
//
// It holds information on how to read and write data into record type R,
diff --git a/select.go b/select.go
index d839c13..e6eccb1 100644
--- a/select.go
+++ b/select.go
@@ -79,16 +79,16 @@ func (s Store[R]) SelectWhere(db Handle, partialQuery string, args ...any) (resu
return result, nil
}
-func startSelectQuery(db Handle, plan internal.Plan, query string, args ...any) (rows *sql.Rows, indexes [][]int, err error) {
- rows, err = db.Query(query, args...)
+func startSelectQuery(db Handle, plan internal.Plan, query string, args ...any) (returnedRows *sql.Rows, indexes [][]int, returnedError error) {
+ rows, err := db.Query(query, args...)
if err != nil {
return nil, nil, fmt.Errorf("during Query(): %w", err)
}
defer func() {
- if err != nil {
- closeErr := rows.Close()
+ if returnedError != nil {
+ closeErr := rows.Close() // NOTE: Not `returnedRows.Close()`! We may have `rows != nil && returnedRows == nil`.
if closeErr != nil {
- err = fmt.Errorf("%w (additional error during rows.Close(): %s)", err, closeErr.Error())
+ returnedError = fmt.Errorf("%w (additional error during rows.Close(): %s)", returnedError, closeErr.Error())
}
}
}()
@@ -103,8 +103,8 @@ func startSelectQuery(db Handle, plan internal.Plan, query string, args ...any)
indexes[idx], ok = plan.IndexByColumnName[columnName]
if !ok {
return nil, nil, fmt.Errorf(
- "result has column %q in position %d, but no field in record type has `db:%[1]q`",
- columnName, idx,
+ "result has column %q in position %d, but no field in type %s has `db:%[1]q`",
+ columnName, idx, plan.TypeName,
)
}
}
@@ -125,11 +125,7 @@ func collectRow(rows *sql.Rows, v reflect.Value, slots []any, indexes [][]int) e
for idx, index := range indexes {
slots[idx] = v.FieldByIndex(index).Addr().Interface()
}
- err := rows.Scan(slots...)
- if err != nil {
- return fmt.Errorf("during rows.Scan(): %w", err)
- }
- return nil
+ return rows.Scan(slots...)
}
func mergeCloseError(typeName string, err, closeErr error) error {
@@ -147,7 +143,6 @@ func mergeCloseError(typeName string, err, closeErr error) error {
// according to the column names reported by the database as part of the result set.
//
// If there are no rows in the result set, [sql.ErrNoRows] is returned.
-// If there are multiple rows in the result set, [ErrMultipleRows] is returned.
//
// Warning: Because of limitations in the interface of database/sql, this function is built on [Store.Select] and cannot be any faster than it.
// For maximum performance, use [Store.SelectOneWhere] which avoids the overhead of potentially having to read multiple rows.
@@ -158,13 +153,10 @@ func (s Store[R]) SelectOne(db Handle, query string, args ...any) (result R, err
var results []R
results, err = s.Select(db, query, args...)
if err == nil {
- switch len(results) {
- case 0:
+ if len(results) == 0 {
err = sql.ErrNoRows
- case 1:
+ } else {
result = results[0]
- default:
- err = ErrMultipleRows
}
}
return
@@ -173,7 +165,8 @@ func (s Store[R]) SelectOne(db Handle, query string, args ...any) (result R, err
// SelectOneWhere is like [Store.SelectOne], but you only provide the part of the SELECT query that comes after the WHERE.
// See [Store.SelectWhere] for an explanation of how the full query is constructed from this partial query.
//
-// This method is significantly more efficient than [Store.SelectWhere] and using it is recommended when possible.
+// This method is significantly more efficient than [Store.SelectOne].
+// Prefer using it instaed of [Store.SelectOne] whenever possible.
func (s Store[R]) SelectOneWhere(db Handle, partialQuery string, args ...any) (result R, err error) {
// NOTE: This function body should be as short as possible to reduce the binary size after monomorphization.
// Any expression that does not depend on type R should be factored out into a reusable function.
diff --git a/select_test.go b/select_test.go
new file mode 100644
index 0000000..d678aa2
--- /dev/null
+++ b/select_test.go
@@ -0,0 +1,222 @@
+// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net>
+// SPDX-License-Identifier: Apache-2.0
+
+package oblast_test
+
+import (
+ "database/sql"
+ "testing"
+ "time"
+
+ "go.xyrillian.de/oblast"
+ "go.xyrillian.de/oblast/internal/assert"
+ "go.xyrillian.de/oblast/internal/mock"
+ "go.xyrillian.de/oblast/internal/must"
+)
+
+func TestSelectReturningSomeRecords(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID int64 `db:"id"`
+ Name string `db:"name"`
+ }
+ store := must.Return(oblast.NewStore[basicRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ ))(t)
+
+ t.Run("using Store.Select", func(t *testing.T) {
+ md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("name", "id").
+ WithRow("foo", 1).
+ WithRow("bar", 2)
+ records := must.Return(store.Select(db, `SELECT * FROM basic_records WHERE id < ?`, 3))(t)
+ assert.SliceEqual(t, records,
+ basicRecord{1, "foo"},
+ basicRecord{2, "bar"},
+ )
+ })
+
+ t.Run("using Store.SelectWhere", func(t *testing.T) {
+ md.ForQuery(`SELECT "id", "name" FROM "basic_records" WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("id", "name").
+ WithRow(1, "ffoo").
+ WithRow(2, "bbar")
+ records := must.Return(store.SelectWhere(db, `id < ?`, 3))(t)
+ assert.SliceEqual(t, records,
+ basicRecord{1, "ffoo"},
+ basicRecord{2, "bbar"},
+ )
+ })
+
+ t.Run("using Store.SelectOne", func(t *testing.T) {
+ md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("name", "id").
+ WithRow("fffoo", 1).
+ WithRow("bbbar", 2)
+ record := must.Return(store.SelectOne(db, `SELECT * FROM basic_records WHERE id < ?`, 3))(t)
+ assert.Equal(t, record, basicRecord{1, "fffoo"})
+ })
+
+ t.Run("using Store.SelectOneWhere", func(t *testing.T) {
+ md.ForQuery(`SELECT "id", "name" FROM "basic_records" WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("id", "name").
+ WithRow(1, "ffffoo").
+ WithRow(2, "bbbbar")
+ record := must.Return(store.SelectOneWhere(db, `id < ?`, 3))(t)
+ assert.Equal(t, record, basicRecord{1, "ffffoo"})
+ })
+}
+
+func TestSelectReturningNoRecords(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID int64 `db:"id"`
+ Name string `db:"name"`
+ }
+ store := must.Return(oblast.NewStore[basicRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ ))(t)
+
+ t.Run("using Store.Select", func(t *testing.T) {
+ md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("name", "id")
+ records := must.Return(store.Select(db, `SELECT * FROM basic_records WHERE id < ?`, 3))(t)
+ assert.SliceEqual(t, records, nil...)
+ })
+
+ t.Run("using Store.SelectWhere", func(t *testing.T) {
+ md.ForQuery(`SELECT "id", "name" FROM "basic_records" WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("id", "name")
+ records := must.Return(store.SelectWhere(db, `id < ?`, 3))(t)
+ assert.SliceEqual(t, records, nil...)
+ })
+
+ t.Run("using Store.SelectOne", func(t *testing.T) {
+ md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("name", "id")
+ _, err := store.SelectOne(db, `SELECT * FROM basic_records WHERE id < ?`, 3)
+ assert.Equal(t, err.Error(), sql.ErrNoRows.Error())
+ })
+
+ t.Run("using Store.SelectOneWhere", func(t *testing.T) {
+ md.ForQuery(`SELECT "id", "name" FROM "basic_records" WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("id", "name")
+ _, err := store.SelectOneWhere(db, `id < ?`, 3)
+ assert.Equal(t, err.Error(), sql.ErrNoRows.Error())
+ })
+}
+
+func TestSelectIntoUnexpectedField(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID int64 `db:"id"`
+ Description string `db:"desc"` // but DB knows only the field "name"!
+ }
+ store := must.Return(oblast.NewStore[basicRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ ))(t)
+
+ expectedError := "result has column \"name\" in position 0, but no field in type basicRecord has `db:\"name\"`"
+
+ // NOTE: This problem cannot occur with SelectWhere() and SelectOneWhere() because of their use of query generation.
+
+ t.Run("using Store.Select", func(t *testing.T) {
+ md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("name", "id").
+ WithRow("foo", 1).
+ WithRow("bar", 2)
+ _, err := store.Select(db, `SELECT * FROM basic_records WHERE id < ?`, 3)
+ assert.Equal(t, err.Error(), expectedError)
+ })
+
+ t.Run("using Store.SelectOne", func(t *testing.T) {
+ md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("name", "id").
+ WithRow("ffoo", 1).
+ WithRow("bbar", 2)
+ _, err := store.SelectOne(db, `SELECT * FROM basic_records WHERE id < ?`, 3)
+ assert.Equal(t, err.Error(), expectedError)
+ })
+}
+
+func TestSelectWithScanError(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID int64 `db:"id"`
+ CreatedAt time.Time `db:"created_at"` // but the DB will give us strings that are not timestamps
+ }
+ store := must.Return(oblast.NewStore[basicRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ ))(t)
+
+ expectedError := `sql: Scan error on column index 1, name "created_at": unsupported Scan, storing driver.Value type string into type *time.Time`
+
+ t.Run("using Store.Select", func(t *testing.T) {
+ md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("id", "created_at").
+ WithRow(1, "foo").
+ WithRow(2, "bar")
+ _, err := store.Select(db, `SELECT * FROM basic_records WHERE id < ?`, 3)
+ assert.Equal(t, err.Error(), expectedError)
+ })
+
+ t.Run("using Store.SelectWhere", func(t *testing.T) {
+ md.ForQuery(`SELECT "id", "created_at" FROM "basic_records" WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("id", "created_at").
+ WithRow(1, "ffoo").
+ WithRow(2, "bbar")
+ _, err := store.SelectWhere(db, `id < ?`, 3)
+ assert.Equal(t, err.Error(), expectedError)
+ })
+
+ t.Run("using Store.SelectOne", func(t *testing.T) {
+ md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("id", "created_at").
+ WithRow(1, "fffoo").
+ WithRow(2, "bbbar")
+ _, err := store.SelectOne(db, `SELECT * FROM basic_records WHERE id < ?`, 3)
+ assert.Equal(t, err.Error(), expectedError)
+ })
+
+ t.Run("using Store.SelectOneWhere", func(t *testing.T) {
+ md.ForQuery(`SELECT "id", "created_at" FROM "basic_records" WHERE id < ?`).
+ ExpectQueryWithArgs(3).
+ AndReturnColumns("id", "created_at").
+ WithRow(1, "ffffoo").
+ WithRow(2, "bbbbar")
+ _, err := store.SelectOneWhere(db, `id < ?`, 3)
+ assert.Equal(t, err.Error(), expectedError)
+ })
+}
+
+// TODO: test error capture during Rows.Close()
+// TODO: check for maximum test coverage in select.go