diff options
| author | Stefan Majewsky <majewsky@gmx.net> | 2026-04-14 00:41:25 +0200 |
|---|---|---|
| committer | Stefan Majewsky <majewsky@gmx.net> | 2026-04-14 00:41:25 +0200 |
| commit | 9191e018ff90deb99f3881966a5d356a05027e0f (patch) | |
| tree | c36880ed2c0755132306141e61c8073d2926b0de | |
| parent | 49a52b73afac2c97a8f3b7cffd434b29e6f30fcf (diff) | |
| download | go-oblast-9191e018ff90deb99f3881966a5d356a05027e0f.tar.gz | |
initial test coverage for Store.Select functions
| -rw-r--r-- | internal/assert/assert.go | 13 | ||||
| -rw-r--r-- | internal/mock/driver.go | 20 | ||||
| -rw-r--r-- | internal/must/must.go | 2 | ||||
| -rw-r--r-- | internal/plan.go | 2 | ||||
| -rw-r--r-- | oblast.go | 4 | ||||
| -rw-r--r-- | select.go | 31 | ||||
| -rw-r--r-- | select_test.go | 222 |
7 files changed, 263 insertions, 31 deletions
diff --git a/internal/assert/assert.go b/internal/assert/assert.go index c82d35c..84b6ecf 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -23,3 +23,16 @@ func DeepEqual[V any](t testing.TB, actual, expected V) { t.Errorf("expected %#v, but got %#v", expected, actual) } } + +// SliceEqual is a test assertion. +func SliceEqual[V comparable](t testing.TB, actual []V, expected ...V) { + t.Helper() + if len(actual) != len(expected) { + t.Errorf("length mismatch: expected %#v, but got %#v", expected, actual) + } + for idx := range actual { + if actual[idx] != expected[idx] { + t.Errorf("element %d: expected %#v, but got %#v", idx, expected[idx], actual[idx]) + } + } +} diff --git a/internal/mock/driver.go b/internal/mock/driver.go index 4183097..d3358c4 100644 --- a/internal/mock/driver.go +++ b/internal/mock/driver.go @@ -75,22 +75,26 @@ func newExpectation[T any](args []any) expectation[T] { output: new(T), } for idx, arg := range args { - e.args[idx] = arg + var err error + e.args[idx], err = driver.DefaultParameterConverter.ConvertValue(arg) + if err != nil { + panic(fmt.Sprintf("could not convert value %#v into driver.Value: %s", arg, err.Error())) + } } return e } -// ExpectExec plans a response to an Exec() call. -func (rs *ResponseSet) ExpectExec(args ...any) *Result { +// ExpectExecWithArgs plans a response to an Exec() call. +func (rs *ResponseSet) ExpectExecWithArgs(args ...any) *Result { e := newExpectation[Result](args) rs.expectedExecs = append(rs.expectedExecs, e) return e.output } -// ExpectQuery plans a response to a Query() or QueryRows() call. -func (rs *ResponseSet) ExpectQuery(args ...any) *Result { - e := newExpectation[Result](args) - rs.expectedExecs = append(rs.expectedExecs, e) +// ExpectQueryWithArgs plans a response to a Query() or QueryRows() call. +func (rs *ResponseSet) ExpectQueryWithArgs(args ...any) *Rows { + e := newExpectation[Rows](args) + rs.expectedQueries = append(rs.expectedQueries, e) return e.output } @@ -258,7 +262,7 @@ func (r *Rows) AndReturnColumns(columns ...string) *Rows { // WithRow adds a row to the result set that will be returned by this query. // This may only be called after AndReturnColumns(). func (r *Rows) WithRow(values ...any) *Rows { - if len(r.columns) != 0 { + if len(r.columns) == 0 { panic("AndReturnColumns() has not been called for this Rows object yet") } if len(r.columns) != len(values) { diff --git a/internal/must/must.go b/internal/must/must.go index e472579..7a137c6 100644 --- a/internal/must/must.go +++ b/internal/must/must.go @@ -7,6 +7,7 @@ import "testing" // Succeed fails the test if err is not nil. func Succeed(t testing.TB, err error) { + t.Helper() if err != nil { t.Fatal(err.Error()) } @@ -16,6 +17,7 @@ func Succeed(t testing.TB, err error) { // and either forwards the result value on success, or fails the test on error. func Return[V any](value V, err error) func(testing.TB) V { return func(t testing.TB) V { + t.Helper() if err != nil { t.Fatal(err.Error()) } diff --git a/internal/plan.go b/internal/plan.go index f619a5f..b57b8dd 100644 --- a/internal/plan.go +++ b/internal/plan.go @@ -14,6 +14,7 @@ import ( // Plan holds all information that we can derive from reflecting on a given type. // The queries held within are only valid within the context of a given SQL dialect. type Plan struct { + TypeName string // for use in error messages TableName string // from info.TableNameIs marker (if any) AllColumnNames []string // in order of struct fields PrimaryKeyColumnNames []string // from info.PrimaryKeyIs marker (if any) @@ -64,6 +65,7 @@ func buildPlan(t reflect.Type, dialect Dialect, opts PlanOpts) (Plan, error) { } var p = Plan{ + TypeName: t.Name(), TableName: opts.TableName, PrimaryKeyColumnNames: opts.PrimaryKeyColumnNames, IndexByColumnName: make(map[string][]int), @@ -42,7 +42,6 @@ package oblast // import "go.xyrillian.de/oblast" import ( "database/sql" - "errors" "reflect" "go.xyrillian.de/oblast/internal" @@ -78,9 +77,6 @@ var ( _ Handle = &sql.Tx{} ) -// ErrMultipleRows is returned by [Store.SelectOne] if the query returned multiple rows. -var ErrMultipleRows = errors.New("sql: multiple rows in result set") - // Store is the main interface of this library. // // It holds information on how to read and write data into record type R, @@ -79,16 +79,16 @@ func (s Store[R]) SelectWhere(db Handle, partialQuery string, args ...any) (resu return result, nil } -func startSelectQuery(db Handle, plan internal.Plan, query string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { - rows, err = db.Query(query, args...) +func startSelectQuery(db Handle, plan internal.Plan, query string, args ...any) (returnedRows *sql.Rows, indexes [][]int, returnedError error) { + rows, err := db.Query(query, args...) if err != nil { return nil, nil, fmt.Errorf("during Query(): %w", err) } defer func() { - if err != nil { - closeErr := rows.Close() + if returnedError != nil { + closeErr := rows.Close() // NOTE: Not `returnedRows.Close()`! We may have `rows != nil && returnedRows == nil`. if closeErr != nil { - err = fmt.Errorf("%w (additional error during rows.Close(): %s)", err, closeErr.Error()) + returnedError = fmt.Errorf("%w (additional error during rows.Close(): %s)", returnedError, closeErr.Error()) } } }() @@ -103,8 +103,8 @@ func startSelectQuery(db Handle, plan internal.Plan, query string, args ...any) indexes[idx], ok = plan.IndexByColumnName[columnName] if !ok { return nil, nil, fmt.Errorf( - "result has column %q in position %d, but no field in record type has `db:%[1]q`", - columnName, idx, + "result has column %q in position %d, but no field in type %s has `db:%[1]q`", + columnName, idx, plan.TypeName, ) } } @@ -125,11 +125,7 @@ func collectRow(rows *sql.Rows, v reflect.Value, slots []any, indexes [][]int) e for idx, index := range indexes { slots[idx] = v.FieldByIndex(index).Addr().Interface() } - err := rows.Scan(slots...) - if err != nil { - return fmt.Errorf("during rows.Scan(): %w", err) - } - return nil + return rows.Scan(slots...) } func mergeCloseError(typeName string, err, closeErr error) error { @@ -147,7 +143,6 @@ func mergeCloseError(typeName string, err, closeErr error) error { // according to the column names reported by the database as part of the result set. // // If there are no rows in the result set, [sql.ErrNoRows] is returned. -// If there are multiple rows in the result set, [ErrMultipleRows] is returned. // // Warning: Because of limitations in the interface of database/sql, this function is built on [Store.Select] and cannot be any faster than it. // For maximum performance, use [Store.SelectOneWhere] which avoids the overhead of potentially having to read multiple rows. @@ -158,13 +153,10 @@ func (s Store[R]) SelectOne(db Handle, query string, args ...any) (result R, err var results []R results, err = s.Select(db, query, args...) if err == nil { - switch len(results) { - case 0: + if len(results) == 0 { err = sql.ErrNoRows - case 1: + } else { result = results[0] - default: - err = ErrMultipleRows } } return @@ -173,7 +165,8 @@ func (s Store[R]) SelectOne(db Handle, query string, args ...any) (result R, err // SelectOneWhere is like [Store.SelectOne], but you only provide the part of the SELECT query that comes after the WHERE. // See [Store.SelectWhere] for an explanation of how the full query is constructed from this partial query. // -// This method is significantly more efficient than [Store.SelectWhere] and using it is recommended when possible. +// This method is significantly more efficient than [Store.SelectOne]. +// Prefer using it instaed of [Store.SelectOne] whenever possible. func (s Store[R]) SelectOneWhere(db Handle, partialQuery string, args ...any) (result R, err error) { // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. // Any expression that does not depend on type R should be factored out into a reusable function. diff --git a/select_test.go b/select_test.go new file mode 100644 index 0000000..d678aa2 --- /dev/null +++ b/select_test.go @@ -0,0 +1,222 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net> +// SPDX-License-Identifier: Apache-2.0 + +package oblast_test + +import ( + "database/sql" + "testing" + "time" + + "go.xyrillian.de/oblast" + "go.xyrillian.de/oblast/internal/assert" + "go.xyrillian.de/oblast/internal/mock" + "go.xyrillian.de/oblast/internal/must" +) + +func TestSelectReturningSomeRecords(t *testing.T) { + md := mock.NewDriver() + db := sql.OpenDB(md) + + type basicRecord struct { + ID int64 `db:"id"` + Name string `db:"name"` + } + store := must.Return(oblast.NewStore[basicRecord]( + oblast.SqliteDialect(), + oblast.TableNameIs("basic_records"), + oblast.PrimaryKeyIs("id"), + ))(t) + + t.Run("using Store.Select", func(t *testing.T) { + md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("name", "id"). + WithRow("foo", 1). + WithRow("bar", 2) + records := must.Return(store.Select(db, `SELECT * FROM basic_records WHERE id < ?`, 3))(t) + assert.SliceEqual(t, records, + basicRecord{1, "foo"}, + basicRecord{2, "bar"}, + ) + }) + + t.Run("using Store.SelectWhere", func(t *testing.T) { + md.ForQuery(`SELECT "id", "name" FROM "basic_records" WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("id", "name"). + WithRow(1, "ffoo"). + WithRow(2, "bbar") + records := must.Return(store.SelectWhere(db, `id < ?`, 3))(t) + assert.SliceEqual(t, records, + basicRecord{1, "ffoo"}, + basicRecord{2, "bbar"}, + ) + }) + + t.Run("using Store.SelectOne", func(t *testing.T) { + md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("name", "id"). + WithRow("fffoo", 1). + WithRow("bbbar", 2) + record := must.Return(store.SelectOne(db, `SELECT * FROM basic_records WHERE id < ?`, 3))(t) + assert.Equal(t, record, basicRecord{1, "fffoo"}) + }) + + t.Run("using Store.SelectOneWhere", func(t *testing.T) { + md.ForQuery(`SELECT "id", "name" FROM "basic_records" WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("id", "name"). + WithRow(1, "ffffoo"). + WithRow(2, "bbbbar") + record := must.Return(store.SelectOneWhere(db, `id < ?`, 3))(t) + assert.Equal(t, record, basicRecord{1, "ffffoo"}) + }) +} + +func TestSelectReturningNoRecords(t *testing.T) { + md := mock.NewDriver() + db := sql.OpenDB(md) + + type basicRecord struct { + ID int64 `db:"id"` + Name string `db:"name"` + } + store := must.Return(oblast.NewStore[basicRecord]( + oblast.SqliteDialect(), + oblast.TableNameIs("basic_records"), + oblast.PrimaryKeyIs("id"), + ))(t) + + t.Run("using Store.Select", func(t *testing.T) { + md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("name", "id") + records := must.Return(store.Select(db, `SELECT * FROM basic_records WHERE id < ?`, 3))(t) + assert.SliceEqual(t, records, nil...) + }) + + t.Run("using Store.SelectWhere", func(t *testing.T) { + md.ForQuery(`SELECT "id", "name" FROM "basic_records" WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("id", "name") + records := must.Return(store.SelectWhere(db, `id < ?`, 3))(t) + assert.SliceEqual(t, records, nil...) + }) + + t.Run("using Store.SelectOne", func(t *testing.T) { + md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("name", "id") + _, err := store.SelectOne(db, `SELECT * FROM basic_records WHERE id < ?`, 3) + assert.Equal(t, err.Error(), sql.ErrNoRows.Error()) + }) + + t.Run("using Store.SelectOneWhere", func(t *testing.T) { + md.ForQuery(`SELECT "id", "name" FROM "basic_records" WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("id", "name") + _, err := store.SelectOneWhere(db, `id < ?`, 3) + assert.Equal(t, err.Error(), sql.ErrNoRows.Error()) + }) +} + +func TestSelectIntoUnexpectedField(t *testing.T) { + md := mock.NewDriver() + db := sql.OpenDB(md) + + type basicRecord struct { + ID int64 `db:"id"` + Description string `db:"desc"` // but DB knows only the field "name"! + } + store := must.Return(oblast.NewStore[basicRecord]( + oblast.SqliteDialect(), + oblast.TableNameIs("basic_records"), + oblast.PrimaryKeyIs("id"), + ))(t) + + expectedError := "result has column \"name\" in position 0, but no field in type basicRecord has `db:\"name\"`" + + // NOTE: This problem cannot occur with SelectWhere() and SelectOneWhere() because of their use of query generation. + + t.Run("using Store.Select", func(t *testing.T) { + md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("name", "id"). + WithRow("foo", 1). + WithRow("bar", 2) + _, err := store.Select(db, `SELECT * FROM basic_records WHERE id < ?`, 3) + assert.Equal(t, err.Error(), expectedError) + }) + + t.Run("using Store.SelectOne", func(t *testing.T) { + md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("name", "id"). + WithRow("ffoo", 1). + WithRow("bbar", 2) + _, err := store.SelectOne(db, `SELECT * FROM basic_records WHERE id < ?`, 3) + assert.Equal(t, err.Error(), expectedError) + }) +} + +func TestSelectWithScanError(t *testing.T) { + md := mock.NewDriver() + db := sql.OpenDB(md) + + type basicRecord struct { + ID int64 `db:"id"` + CreatedAt time.Time `db:"created_at"` // but the DB will give us strings that are not timestamps + } + store := must.Return(oblast.NewStore[basicRecord]( + oblast.SqliteDialect(), + oblast.TableNameIs("basic_records"), + oblast.PrimaryKeyIs("id"), + ))(t) + + expectedError := `sql: Scan error on column index 1, name "created_at": unsupported Scan, storing driver.Value type string into type *time.Time` + + t.Run("using Store.Select", func(t *testing.T) { + md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("id", "created_at"). + WithRow(1, "foo"). + WithRow(2, "bar") + _, err := store.Select(db, `SELECT * FROM basic_records WHERE id < ?`, 3) + assert.Equal(t, err.Error(), expectedError) + }) + + t.Run("using Store.SelectWhere", func(t *testing.T) { + md.ForQuery(`SELECT "id", "created_at" FROM "basic_records" WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("id", "created_at"). + WithRow(1, "ffoo"). + WithRow(2, "bbar") + _, err := store.SelectWhere(db, `id < ?`, 3) + assert.Equal(t, err.Error(), expectedError) + }) + + t.Run("using Store.SelectOne", func(t *testing.T) { + md.ForQuery(`SELECT * FROM basic_records WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("id", "created_at"). + WithRow(1, "fffoo"). + WithRow(2, "bbbar") + _, err := store.SelectOne(db, `SELECT * FROM basic_records WHERE id < ?`, 3) + assert.Equal(t, err.Error(), expectedError) + }) + + t.Run("using Store.SelectOneWhere", func(t *testing.T) { + md.ForQuery(`SELECT "id", "created_at" FROM "basic_records" WHERE id < ?`). + ExpectQueryWithArgs(3). + AndReturnColumns("id", "created_at"). + WithRow(1, "ffffoo"). + WithRow(2, "bbbbar") + _, err := store.SelectOneWhere(db, `id < ?`, 3) + assert.Equal(t, err.Error(), expectedError) + }) +} + +// TODO: test error capture during Rows.Close() +// TODO: check for maximum test coverage in select.go |
