From e45a8be0dcfc375963a061d83e04429995053da1 Mon Sep 17 00:00:00 2001 From: Stefan Majewsky Date: Fri, 24 Apr 2026 15:46:24 +0200 Subject: exclude testhelpers from coverage testing --- Makefile | 2 +- benchmark/benchmark_test.go | 4 +- errors_test.go | 2 +- internal/assert/assert.go | 63 ------- internal/mock/mock.go | 320 ---------------------------------- internal/must/must.go | 26 --- internal/testhelpers/assert/assert.go | 63 +++++++ internal/testhelpers/mock/mock.go | 320 ++++++++++++++++++++++++++++++++++ internal/testhelpers/must/must.go | 26 +++ plan_test.go | 2 +- query_test.go | 6 +- select_test.go | 6 +- 12 files changed, 420 insertions(+), 420 deletions(-) delete mode 100644 internal/assert/assert.go delete mode 100644 internal/mock/mock.go delete mode 100644 internal/must/must.go create mode 100644 internal/testhelpers/assert/assert.go create mode 100644 internal/testhelpers/mock/mock.go create mode 100644 internal/testhelpers/must/must.go diff --git a/Makefile b/Makefile index 834305f..d364193 100644 --- a/Makefile +++ b/Makefile @@ -15,7 +15,7 @@ static-check: FORCE benchmark: FORCE @cd benchmark && go test -bench . -benchmem . -GO_COVERPKGS := $(shell go list ./... | tr '\n' , | sed 's/,$$//') +GO_COVERPKGS := $(shell go list ./... | grep -vw testhelpers | tr '\n' , | sed 's/,$$//') GO_TESTPKGS := $(shell go list -f '{{if or .TestGoFiles .XTestGoFiles}}{{.ImportPath}}{{end}}' ./...) build/cover.out: FORCE diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index d18c0cb..7cf94b7 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -14,8 +14,8 @@ import ( "github.com/go-gorp/gorp/v3" _ "github.com/mattn/go-sqlite3" "go.xyrillian.de/oblast" - "go.xyrillian.de/oblast/internal/assert" - "go.xyrillian.de/oblast/internal/must" + "go.xyrillian.de/oblast/internal/testhelpers/assert" + "go.xyrillian.de/oblast/internal/testhelpers/must" "gorm.io/driver/sqlite" "gorm.io/gorm" ) diff --git a/errors_test.go b/errors_test.go index c39cf67..81ee395 100644 --- a/errors_test.go +++ b/errors_test.go @@ -7,7 +7,7 @@ import ( "errors" "testing" - "go.xyrillian.de/oblast/internal/assert" + "go.xyrillian.de/oblast/internal/testhelpers/assert" ) type fooError struct{} diff --git a/internal/assert/assert.go b/internal/assert/assert.go deleted file mode 100644 index 6e641ca..0000000 --- a/internal/assert/assert.go +++ /dev/null @@ -1,63 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Stefan Majewsky -// SPDX-License-Identifier: Apache-2.0 - -package assert - -import ( - "cmp" - "errors" - "reflect" - "testing" -) - -// Equal is a test assertion. -func Equal[V comparable](t testing.TB, actual, expected V) { - t.Helper() - if actual != expected { - t.Errorf("expected %#v", expected) - t.Errorf(" but got %#v", actual) - } -} - -// DeepEqual is a test assertion. -func DeepEqual[V any](t testing.TB, actual, expected V) { - t.Helper() - if !reflect.DeepEqual(actual, expected) { - t.Errorf("expected %#v", expected) - t.Errorf(" but got %#v", actual) - } -} - -// ErrEqual is a test assertion. -func ErrEqual(t testing.TB, actual error, expected string) { - t.Helper() - Equal(t, cmp.Or(actual, errors.New("")).Error(), expected) -} - -// SliceEqual is a test assertion. -func SliceEqual[V comparable](t testing.TB, actual []V, expected ...V) { - t.Helper() - if len(actual) != len(expected) { - t.Errorf("length mismatch: expected %#v, but got %#v", expected, actual) - } - for idx := range actual { - if actual[idx] != expected[idx] { - t.Errorf("element %d: expected %#v", idx, expected[idx]) - t.Errorf("element %d: but got %#v", idx, actual[idx]) - } - } -} - -// SliceDeepEqual is a test assertion. -func SliceDeepEqual[V any](t testing.TB, actual []V, expected ...V) { - t.Helper() - if len(actual) != len(expected) { - t.Errorf("length mismatch: expected %#v, but got %#v", expected, actual) - } - for idx := range actual { - if !reflect.DeepEqual(actual[idx], expected[idx]) { - t.Errorf("element %d: expected %#v", idx, expected[idx]) - t.Errorf("element %d: but got %#v", idx, actual[idx]) - } - } -} diff --git a/internal/mock/mock.go b/internal/mock/mock.go deleted file mode 100644 index 6265166..0000000 --- a/internal/mock/mock.go +++ /dev/null @@ -1,320 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Stefan Majewsky -// SPDX-License-Identifier: Apache-2.0 - -package mock - -import ( - "context" - "database/sql/driver" - "errors" - "fmt" - "io" - "reflect" - "slices" - "strings" -) - -//////////////////////////////////////////////////////////////////////////////// -// type Driver - -// Driver is a mock SQL driver that only accepts queries that were preannounced. -type Driver struct { - responseSetsByQuery map[string]*ResponseSet -} - -// assert that interface is implemented -var _ driver.Connector = &Driver{} - -// NewDriver instantiates a new driver. -// The result returns [driver.Connector] and can be given to [sql.OpenDB]. -func NewDriver() *Driver { - return &Driver{ - responseSetsByQuery: make(map[string]*ResponseSet), - } -} - -// Connect implements the [driver.Connector] interface. -func (d *Driver) Connect(ctx context.Context) (driver.Conn, error) { - return &connection{d: d}, nil -} - -// Driver implements the [driver.Connector] interface. -func (d *Driver) Driver() driver.Driver { - // Not needed. Implementing the Driver interface would only be necessary if - // we wanted to use sql.Open() instead of sql.OpenDB(), or if we wanted to - // use sql.DB.Driver(). - panic("unimplemented") -} - -// ForQuery tells the driver to expect the given query string to be sent soon. -// The return value can be used to plan what to return when the query is actually executed. -func (d *Driver) ForQuery(query string) *ResponseSet { - if d.responseSetsByQuery[query] == nil { - d.responseSetsByQuery[query] = &ResponseSet{} - } - return d.responseSetsByQuery[query] -} - -//////////////////////////////////////////////////////////////////////////////// -// type ResponseSet - -// ResponseSet is a set of mock responses for a query sent to type [Driver]. -type ResponseSet struct { - expectedExecs []expectation[Result] - expectedQueries []expectation[Rows] -} - -type expectation[T any] struct { - args []driver.Value - output *T -} - -func newExpectation[T any](args []any) expectation[T] { - e := expectation[T]{ - args: make([]driver.Value, len(args)), - output: new(T), - } - for idx, arg := range args { - var err error - e.args[idx], err = driver.DefaultParameterConverter.ConvertValue(arg) - if err != nil { - panic(fmt.Sprintf("could not convert value %#v into driver.Value: %s", arg, err.Error())) - } - } - return e -} - -// ExpectExecWithArgs plans a response to an Exec() call. -func (rs *ResponseSet) ExpectExecWithArgs(args ...any) *Result { - e := newExpectation[Result](args) - rs.expectedExecs = append(rs.expectedExecs, e) - return e.output -} - -// ExpectQueryWithArgs plans a response to a Query() or QueryRows() call. -func (rs *ResponseSet) ExpectQueryWithArgs(args ...any) *Rows { - e := newExpectation[Rows](args) - rs.expectedQueries = append(rs.expectedQueries, e) - return e.output -} - -//////////////////////////////////////////////////////////////////////////////// -// type connection - -type connection struct { - d *Driver - closed bool -} - -// Prepare implements the [driver.Conn] interface. -func (c *connection) Prepare(query string) (driver.Stmt, error) { - rs := c.d.responseSetsByQuery[query] - if rs == nil { - return nil, fmt.Errorf("unexpected query: %s", query) - } - return &statement{c: c, query: query, rs: rs}, nil -} - -// Close implements the [driver.Conn] interface. -func (c *connection) Close() error { - c.closed = true - return nil -} - -// Begin implements the [driver.Conn] interface. -func (c *connection) Begin() (driver.Tx, error) { - return transaction{}, nil -} - -//////////////////////////////////////////////////////////////////////////////// -// type transaction - -type transaction struct{} - -// Commit implements the [driver.Tx] interface. -func (t transaction) Commit() error { - return nil // unused -} - -// Rollback implements the [driver.Tx] interface. -func (t transaction) Rollback() error { - return nil // unused -} - -//////////////////////////////////////////////////////////////////////////////// -// type statement - -type statement struct { - c *connection - query string - rs *ResponseSet - closed bool -} - -// Close implements the [driver.Stmt] interface. -func (s *statement) Close() error { - return nil -} - -// NumInput implements the [driver.Stmt] interface. -func (s *statement) NumInput() int { - // option 1: when using SQLite dialect, count `?` - count := strings.Count(s.query, "?") - if count > 0 { - return count - } - - // option 2: when using PostgreSQL dialect, find `$1`, `$2`, etc. - for strings.Contains(s.query, fmt.Sprintf("$%d", count+1)) { - count++ - } - return count -} - -// Exec implements the [driver.Stmt] interface. -func (s *statement) Exec(args []driver.Value) (driver.Result, error) { - if s.closed { - return nil, errors.New("statement was closed") - } - if s.c.closed { - return nil, errors.New("connection was closed") - } - for idx, e := range s.rs.expectedExecs { - if reflect.DeepEqual(e.args, args) { - s.rs.expectedExecs = slices.Delete(s.rs.expectedExecs, idx, idx+1) - return result{r: *e.output}, nil - } - } - return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args) -} - -// Query implements the [driver.Stmt] interface. -func (s *statement) Query(args []driver.Value) (driver.Rows, error) { - if s.closed { - return nil, errors.New("statement was closed") - } - if s.c.closed { - return nil, errors.New("connection was closed") - } - for idx, e := range s.rs.expectedQueries { - if reflect.DeepEqual(e.args, args) { - s.rs.expectedQueries = slices.Delete(s.rs.expectedQueries, idx, idx+1) - return &rows{r: *e.output}, nil - } - } - return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args) -} - -/////////////////////////////////////////////////////////////////////////////////////////// -// type Result - -// Result is a mock response for an Exec() call. -// It is constructed by [ResponseSet.ExpectExec]. -type Result struct { - lastInsertId *int64 - rowsAffected *int64 -} - -// AndReturnLastInsertId configures a mock LastInsertId() value for this Result. -// Returns the same Result instance to allow chaining additional method calls. -func (r *Result) AndReturnLastInsertId(id int64) *Result { - r.lastInsertId = &id - return r -} - -// AndReturnRowsAffected configures a mock RowsAffected() value for this Result. -// Returns the same Result instance to allow chaining additional method calls. -func (r *Result) AndReturnRowsAffected(count int64) *Result { - r.rowsAffected = &count - return r -} - -type result struct { - r Result -} - -// LastInsertId implements the [driver.Result] interface. -func (r result) LastInsertId() (int64, error) { - if r.r.lastInsertId == nil { - return 0, errors.New("AndReturnLastInsertId() was not called for this Result") - } - return *r.r.lastInsertId, nil -} - -// RowsAffected implements the [driver.Result] interface. -func (r result) RowsAffected() (int64, error) { - if r.r.rowsAffected == nil { - return 0, errors.New("AndReturnRowsAffected() was not called for this Result") - } - return *r.r.rowsAffected, nil -} - -// ///////////////////////////////////////////////////////////////////////////////////////// -// type Rows - -// Rows is a mock response for a Query() or QueryRow() call. -// It is constructed by [ResponseSet.ExpectQuery]. -type Rows struct { - columns []string - results [][]any - closeError error -} - -// AndReturnColumns configures the set of column names that will be returend by this query. -// Returns the same Result instance to allow chaining additional method calls. -func (r *Rows) AndReturnColumns(columns ...string) *Rows { - if len(r.columns) > 0 { - panic("AndReturnColumns() called multiple times for the same Rows object") - } - r.columns = columns - return r -} - -// WithRow adds a row to the result set that will be returned by this query. -// This may only be called after AndReturnColumns(). -func (r *Rows) WithRow(values ...any) *Rows { - if len(r.columns) == 0 { - panic("AndReturnColumns() has not been called for this Rows object yet") - } - if len(r.columns) != len(values) { - panic("WithRow() must be called with the same number of args as the preceding AndReturnColumns() call") - } - r.results = append(r.results, values) - return r -} - -// AndCloseFailsWith sets up Close() for this Rows to fail with the provided error message. -func (r *Rows) AndCloseFailsWith(err error) { - r.closeError = err -} - -type rows struct { - r Rows - closed bool -} - -// Columns implements the [driver.Rows] interface. -func (r *rows) Columns() []string { - return r.r.columns -} - -// Close implements the [driver.Rows] interface. -func (r *rows) Close() error { - r.closed = true - return r.r.closeError -} - -// Next implements the [driver.Rows] interface. -func (r *rows) Next(dest []driver.Value) error { - if r.closed { - return errors.New("rows object was closed") - } - if len(r.r.results) == 0 { - return io.EOF - } - for idx, value := range r.r.results[0] { - dest[idx] = value - } - r.r.results = r.r.results[1:] - return nil -} diff --git a/internal/must/must.go b/internal/must/must.go deleted file mode 100644 index 7a137c6..0000000 --- a/internal/must/must.go +++ /dev/null @@ -1,26 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Stefan Majewsky -// SPDX-License-Identifier: Apache-2.0 - -package must - -import "testing" - -// Succeed fails the test if err is not nil. -func Succeed(t testing.TB, err error) { - t.Helper() - if err != nil { - t.Fatal(err.Error()) - } -} - -// Return wraps a function returning two output values, -// and either forwards the result value on success, or fails the test on error. -func Return[V any](value V, err error) func(testing.TB) V { - return func(t testing.TB) V { - t.Helper() - if err != nil { - t.Fatal(err.Error()) - } - return value - } -} diff --git a/internal/testhelpers/assert/assert.go b/internal/testhelpers/assert/assert.go new file mode 100644 index 0000000..6e641ca --- /dev/null +++ b/internal/testhelpers/assert/assert.go @@ -0,0 +1,63 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky +// SPDX-License-Identifier: Apache-2.0 + +package assert + +import ( + "cmp" + "errors" + "reflect" + "testing" +) + +// Equal is a test assertion. +func Equal[V comparable](t testing.TB, actual, expected V) { + t.Helper() + if actual != expected { + t.Errorf("expected %#v", expected) + t.Errorf(" but got %#v", actual) + } +} + +// DeepEqual is a test assertion. +func DeepEqual[V any](t testing.TB, actual, expected V) { + t.Helper() + if !reflect.DeepEqual(actual, expected) { + t.Errorf("expected %#v", expected) + t.Errorf(" but got %#v", actual) + } +} + +// ErrEqual is a test assertion. +func ErrEqual(t testing.TB, actual error, expected string) { + t.Helper() + Equal(t, cmp.Or(actual, errors.New("")).Error(), expected) +} + +// SliceEqual is a test assertion. +func SliceEqual[V comparable](t testing.TB, actual []V, expected ...V) { + t.Helper() + if len(actual) != len(expected) { + t.Errorf("length mismatch: expected %#v, but got %#v", expected, actual) + } + for idx := range actual { + if actual[idx] != expected[idx] { + t.Errorf("element %d: expected %#v", idx, expected[idx]) + t.Errorf("element %d: but got %#v", idx, actual[idx]) + } + } +} + +// SliceDeepEqual is a test assertion. +func SliceDeepEqual[V any](t testing.TB, actual []V, expected ...V) { + t.Helper() + if len(actual) != len(expected) { + t.Errorf("length mismatch: expected %#v, but got %#v", expected, actual) + } + for idx := range actual { + if !reflect.DeepEqual(actual[idx], expected[idx]) { + t.Errorf("element %d: expected %#v", idx, expected[idx]) + t.Errorf("element %d: but got %#v", idx, actual[idx]) + } + } +} diff --git a/internal/testhelpers/mock/mock.go b/internal/testhelpers/mock/mock.go new file mode 100644 index 0000000..6265166 --- /dev/null +++ b/internal/testhelpers/mock/mock.go @@ -0,0 +1,320 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky +// SPDX-License-Identifier: Apache-2.0 + +package mock + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "io" + "reflect" + "slices" + "strings" +) + +//////////////////////////////////////////////////////////////////////////////// +// type Driver + +// Driver is a mock SQL driver that only accepts queries that were preannounced. +type Driver struct { + responseSetsByQuery map[string]*ResponseSet +} + +// assert that interface is implemented +var _ driver.Connector = &Driver{} + +// NewDriver instantiates a new driver. +// The result returns [driver.Connector] and can be given to [sql.OpenDB]. +func NewDriver() *Driver { + return &Driver{ + responseSetsByQuery: make(map[string]*ResponseSet), + } +} + +// Connect implements the [driver.Connector] interface. +func (d *Driver) Connect(ctx context.Context) (driver.Conn, error) { + return &connection{d: d}, nil +} + +// Driver implements the [driver.Connector] interface. +func (d *Driver) Driver() driver.Driver { + // Not needed. Implementing the Driver interface would only be necessary if + // we wanted to use sql.Open() instead of sql.OpenDB(), or if we wanted to + // use sql.DB.Driver(). + panic("unimplemented") +} + +// ForQuery tells the driver to expect the given query string to be sent soon. +// The return value can be used to plan what to return when the query is actually executed. +func (d *Driver) ForQuery(query string) *ResponseSet { + if d.responseSetsByQuery[query] == nil { + d.responseSetsByQuery[query] = &ResponseSet{} + } + return d.responseSetsByQuery[query] +} + +//////////////////////////////////////////////////////////////////////////////// +// type ResponseSet + +// ResponseSet is a set of mock responses for a query sent to type [Driver]. +type ResponseSet struct { + expectedExecs []expectation[Result] + expectedQueries []expectation[Rows] +} + +type expectation[T any] struct { + args []driver.Value + output *T +} + +func newExpectation[T any](args []any) expectation[T] { + e := expectation[T]{ + args: make([]driver.Value, len(args)), + output: new(T), + } + for idx, arg := range args { + var err error + e.args[idx], err = driver.DefaultParameterConverter.ConvertValue(arg) + if err != nil { + panic(fmt.Sprintf("could not convert value %#v into driver.Value: %s", arg, err.Error())) + } + } + return e +} + +// ExpectExecWithArgs plans a response to an Exec() call. +func (rs *ResponseSet) ExpectExecWithArgs(args ...any) *Result { + e := newExpectation[Result](args) + rs.expectedExecs = append(rs.expectedExecs, e) + return e.output +} + +// ExpectQueryWithArgs plans a response to a Query() or QueryRows() call. +func (rs *ResponseSet) ExpectQueryWithArgs(args ...any) *Rows { + e := newExpectation[Rows](args) + rs.expectedQueries = append(rs.expectedQueries, e) + return e.output +} + +//////////////////////////////////////////////////////////////////////////////// +// type connection + +type connection struct { + d *Driver + closed bool +} + +// Prepare implements the [driver.Conn] interface. +func (c *connection) Prepare(query string) (driver.Stmt, error) { + rs := c.d.responseSetsByQuery[query] + if rs == nil { + return nil, fmt.Errorf("unexpected query: %s", query) + } + return &statement{c: c, query: query, rs: rs}, nil +} + +// Close implements the [driver.Conn] interface. +func (c *connection) Close() error { + c.closed = true + return nil +} + +// Begin implements the [driver.Conn] interface. +func (c *connection) Begin() (driver.Tx, error) { + return transaction{}, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// type transaction + +type transaction struct{} + +// Commit implements the [driver.Tx] interface. +func (t transaction) Commit() error { + return nil // unused +} + +// Rollback implements the [driver.Tx] interface. +func (t transaction) Rollback() error { + return nil // unused +} + +//////////////////////////////////////////////////////////////////////////////// +// type statement + +type statement struct { + c *connection + query string + rs *ResponseSet + closed bool +} + +// Close implements the [driver.Stmt] interface. +func (s *statement) Close() error { + return nil +} + +// NumInput implements the [driver.Stmt] interface. +func (s *statement) NumInput() int { + // option 1: when using SQLite dialect, count `?` + count := strings.Count(s.query, "?") + if count > 0 { + return count + } + + // option 2: when using PostgreSQL dialect, find `$1`, `$2`, etc. + for strings.Contains(s.query, fmt.Sprintf("$%d", count+1)) { + count++ + } + return count +} + +// Exec implements the [driver.Stmt] interface. +func (s *statement) Exec(args []driver.Value) (driver.Result, error) { + if s.closed { + return nil, errors.New("statement was closed") + } + if s.c.closed { + return nil, errors.New("connection was closed") + } + for idx, e := range s.rs.expectedExecs { + if reflect.DeepEqual(e.args, args) { + s.rs.expectedExecs = slices.Delete(s.rs.expectedExecs, idx, idx+1) + return result{r: *e.output}, nil + } + } + return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args) +} + +// Query implements the [driver.Stmt] interface. +func (s *statement) Query(args []driver.Value) (driver.Rows, error) { + if s.closed { + return nil, errors.New("statement was closed") + } + if s.c.closed { + return nil, errors.New("connection was closed") + } + for idx, e := range s.rs.expectedQueries { + if reflect.DeepEqual(e.args, args) { + s.rs.expectedQueries = slices.Delete(s.rs.expectedQueries, idx, idx+1) + return &rows{r: *e.output}, nil + } + } + return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args) +} + +/////////////////////////////////////////////////////////////////////////////////////////// +// type Result + +// Result is a mock response for an Exec() call. +// It is constructed by [ResponseSet.ExpectExec]. +type Result struct { + lastInsertId *int64 + rowsAffected *int64 +} + +// AndReturnLastInsertId configures a mock LastInsertId() value for this Result. +// Returns the same Result instance to allow chaining additional method calls. +func (r *Result) AndReturnLastInsertId(id int64) *Result { + r.lastInsertId = &id + return r +} + +// AndReturnRowsAffected configures a mock RowsAffected() value for this Result. +// Returns the same Result instance to allow chaining additional method calls. +func (r *Result) AndReturnRowsAffected(count int64) *Result { + r.rowsAffected = &count + return r +} + +type result struct { + r Result +} + +// LastInsertId implements the [driver.Result] interface. +func (r result) LastInsertId() (int64, error) { + if r.r.lastInsertId == nil { + return 0, errors.New("AndReturnLastInsertId() was not called for this Result") + } + return *r.r.lastInsertId, nil +} + +// RowsAffected implements the [driver.Result] interface. +func (r result) RowsAffected() (int64, error) { + if r.r.rowsAffected == nil { + return 0, errors.New("AndReturnRowsAffected() was not called for this Result") + } + return *r.r.rowsAffected, nil +} + +// ///////////////////////////////////////////////////////////////////////////////////////// +// type Rows + +// Rows is a mock response for a Query() or QueryRow() call. +// It is constructed by [ResponseSet.ExpectQuery]. +type Rows struct { + columns []string + results [][]any + closeError error +} + +// AndReturnColumns configures the set of column names that will be returend by this query. +// Returns the same Result instance to allow chaining additional method calls. +func (r *Rows) AndReturnColumns(columns ...string) *Rows { + if len(r.columns) > 0 { + panic("AndReturnColumns() called multiple times for the same Rows object") + } + r.columns = columns + return r +} + +// WithRow adds a row to the result set that will be returned by this query. +// This may only be called after AndReturnColumns(). +func (r *Rows) WithRow(values ...any) *Rows { + if len(r.columns) == 0 { + panic("AndReturnColumns() has not been called for this Rows object yet") + } + if len(r.columns) != len(values) { + panic("WithRow() must be called with the same number of args as the preceding AndReturnColumns() call") + } + r.results = append(r.results, values) + return r +} + +// AndCloseFailsWith sets up Close() for this Rows to fail with the provided error message. +func (r *Rows) AndCloseFailsWith(err error) { + r.closeError = err +} + +type rows struct { + r Rows + closed bool +} + +// Columns implements the [driver.Rows] interface. +func (r *rows) Columns() []string { + return r.r.columns +} + +// Close implements the [driver.Rows] interface. +func (r *rows) Close() error { + r.closed = true + return r.r.closeError +} + +// Next implements the [driver.Rows] interface. +func (r *rows) Next(dest []driver.Value) error { + if r.closed { + return errors.New("rows object was closed") + } + if len(r.r.results) == 0 { + return io.EOF + } + for idx, value := range r.r.results[0] { + dest[idx] = value + } + r.r.results = r.r.results[1:] + return nil +} diff --git a/internal/testhelpers/must/must.go b/internal/testhelpers/must/must.go new file mode 100644 index 0000000..7a137c6 --- /dev/null +++ b/internal/testhelpers/must/must.go @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky +// SPDX-License-Identifier: Apache-2.0 + +package must + +import "testing" + +// Succeed fails the test if err is not nil. +func Succeed(t testing.TB, err error) { + t.Helper() + if err != nil { + t.Fatal(err.Error()) + } +} + +// Return wraps a function returning two output values, +// and either forwards the result value on success, or fails the test on error. +func Return[V any](value V, err error) func(testing.TB) V { + return func(t testing.TB) V { + t.Helper() + if err != nil { + t.Fatal(err.Error()) + } + return value + } +} diff --git a/plan_test.go b/plan_test.go index b3eeac5..5279a13 100644 --- a/plan_test.go +++ b/plan_test.go @@ -10,7 +10,7 @@ import ( "testing" "time" - "go.xyrillian.de/oblast/internal/assert" + "go.xyrillian.de/oblast/internal/testhelpers/assert" ) func TestPlanFieldTraversal(t *testing.T) { diff --git a/query_test.go b/query_test.go index 29cb015..2809f6e 100644 --- a/query_test.go +++ b/query_test.go @@ -9,9 +9,9 @@ import ( "testing" "go.xyrillian.de/oblast" - "go.xyrillian.de/oblast/internal/assert" - "go.xyrillian.de/oblast/internal/mock" - "go.xyrillian.de/oblast/internal/must" + "go.xyrillian.de/oblast/internal/testhelpers/assert" + "go.xyrillian.de/oblast/internal/testhelpers/mock" + "go.xyrillian.de/oblast/internal/testhelpers/must" ) func TestInsertBasic(t *testing.T) { diff --git a/select_test.go b/select_test.go index 51b0912..f38fbdd 100644 --- a/select_test.go +++ b/select_test.go @@ -10,9 +10,9 @@ import ( "time" "go.xyrillian.de/oblast" - "go.xyrillian.de/oblast/internal/assert" - "go.xyrillian.de/oblast/internal/mock" - "go.xyrillian.de/oblast/internal/must" + "go.xyrillian.de/oblast/internal/testhelpers/assert" + "go.xyrillian.de/oblast/internal/testhelpers/mock" + "go.xyrillian.de/oblast/internal/testhelpers/must" ) func TestSelectReturningSomeRecords(t *testing.T) { -- cgit v1.2.3