aboutsummaryrefslogtreecommitdiff
path: root/internal/mock/mock.go
blob: d3358c46f64c5a1bd5d16b985718cda7e11a0532 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net>
// SPDX-License-Identifier: Apache-2.0

package mock

import (
	"context"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
	"reflect"
	"slices"
	"strings"
)

////////////////////////////////////////////////////////////////////////////////
// type Driver

// Driver is a mock SQL driver that rejects every query which was not
// preannounced with [Driver.ForQuery].
type Driver struct {
	// responseSetsByQuery maps each announced query string (matched verbatim)
	// to the responses planned for it.
	responseSetsByQuery map[string]*ResponseSet
}

// Compile-time check that *Driver satisfies [driver.Connector].
var _ driver.Connector = (*Driver)(nil)

// NewDriver returns a fresh *Driver that has no planned queries yet.
// Because *Driver implements [driver.Connector], the result can be passed
// directly to [sql.OpenDB].
func NewDriver() *Driver {
	d := &Driver{}
	d.responseSetsByQuery = map[string]*ResponseSet{}
	return d
}

// Connect implements the [driver.Connector] interface. Every call hands out a
// new connection that serves the planned responses of this Driver; it never
// returns an error.
func (d *Driver) Connect(ctx context.Context) (driver.Conn, error) {
	conn := &connection{d: d}
	return conn, nil
}

// Driver implements the [driver.Connector] interface.
//
// It deliberately panics. A [driver.Driver] would only be required to support
// sql.Open() (as opposed to sql.OpenDB()) or callers of sql.DB.Driver(),
// neither of which this mock needs.
func (d *Driver) Driver() driver.Driver {
	panic("unimplemented")
}

// ForQuery announces that the given query string will be sent to the driver
// soon. Repeated calls with the same string return the same *ResponseSet,
// through which the responses for that query are planned.
func (d *Driver) ForQuery(query string) *ResponseSet {
	rs := d.responseSetsByQuery[query]
	if rs != nil {
		return rs
	}
	rs = &ResponseSet{}
	d.responseSetsByQuery[query] = rs
	return rs
}

////////////////////////////////////////////////////////////////////////////////
// type ResponseSet

// ResponseSet holds the planned mock responses for one query string sent to
// type [Driver]. Each planned response is used up by the first call whose
// arguments match it.
type ResponseSet struct {
	// expectedExecs are the pending responses for Exec() calls.
	expectedExecs []expectation[Result]

	// expectedQueries are the pending responses for Query() calls.
	expectedQueries []expectation[Rows]
}

// expectation is a single planned response: a call whose arguments equal args
// is answered with the value behind output.
type expectation[T any] struct {
	// args are the expected call arguments, already converted to driver values.
	args []driver.Value

	// output is allocated up front so that it can still be configured after
	// the expectation has been planned.
	output *T
}

// newExpectation builds an expectation for the given arguments with a fresh,
// zero-valued output. The arguments are converted with
// [driver.DefaultParameterConverter]; an argument that cannot be converted
// causes a panic rather than an error return.
func newExpectation[T any](args []any) expectation[T] {
	converted := make([]driver.Value, len(args))
	for i := range args {
		v, err := driver.DefaultParameterConverter.ConvertValue(args[i])
		if err != nil {
			panic(fmt.Sprintf("could not convert value %#v into driver.Value: %s", args[i], err.Error()))
		}
		converted[i] = v
	}
	return expectation[T]{args: converted, output: new(T)}
}

// ExpectExecWithArgs plans a response to an Exec() call made with exactly the
// given arguments. The response is configured through the returned *Result,
// which must happen before the matching call arrives.
func (rs *ResponseSet) ExpectExecWithArgs(args ...any) *Result {
	planned := newExpectation[Result](args)
	rs.expectedExecs = append(rs.expectedExecs, planned)
	return planned.output
}

// ExpectQueryWithArgs plans a response to a Query() or QueryRow() call made
// with exactly the given arguments. The response is configured through the
// returned *Rows, which must happen before the matching call arrives.
func (rs *ResponseSet) ExpectQueryWithArgs(args ...any) *Rows {
	e := newExpectation[Rows](args)
	rs.expectedQueries = append(rs.expectedQueries, e)
	return e.output
}

////////////////////////////////////////////////////////////////////////////////
// type connection

// connection is the [driver.Conn] handed out by [Driver.Connect].
type connection struct {
	// d is the Driver whose planned responses this connection serves.
	d *Driver

	// closed is set by Close; statements refuse to execute once it is set.
	closed bool
}

// Prepare implements the [driver.Conn] interface. Only query strings that
// were announced via [Driver.ForQuery] can be prepared; anything else fails.
func (c *connection) Prepare(query string) (driver.Stmt, error) {
	if rs := c.d.responseSetsByQuery[query]; rs != nil {
		return &statement{c: c, query: query, rs: rs}, nil
	}
	return nil, fmt.Errorf("unexpected query: %s", query)
}

// Close implements the [driver.Conn] interface. It only flags the connection
// as closed and never reports an error.
func (c *connection) Close() error {
	c.closed = true
	return nil
}

// Begin implements the [driver.Conn] interface. The returned transaction is a
// no-op whose Commit and Rollback always succeed.
func (c *connection) Begin() (driver.Tx, error) {
	var tx transaction
	return tx, nil
}

////////////////////////////////////////////////////////////////////////////////
// type transaction

// transaction is a no-op [driver.Tx]; it carries no state at all.
type transaction struct{}

// Commit implements the [driver.Tx] interface. There is nothing to commit.
func (transaction) Commit() error {
	return nil
}

// Rollback implements the [driver.Tx] interface. There is nothing to roll back.
func (transaction) Rollback() error {
	return nil
}

////////////////////////////////////////////////////////////////////////////////
// type statement

// statement is the [driver.Stmt] returned by connection.Prepare. It answers
// calls from the planned responses for its query.
type statement struct {
	// c is the connection that prepared this statement.
	c *connection

	// query is the exact query text that was prepared.
	query string

	// rs holds the planned responses for query.
	rs *ResponseSet

	// closed makes Exec and Query fail with an error when set.
	closed bool
}

// Close implements the [driver.Stmt] interface. It marks the statement as
// closed so that any later Exec or Query call on it fails with an error
// instead of still consuming planned responses. The connection that prepared
// the statement is not affected.
func (s *statement) Close() error {
	s.closed = true
	return nil
}

// NumInput implements the [driver.Stmt] interface by counting the "?"
// placeholders in the query text.
func (s *statement) NumInput() int {
	// NOTE: extremely crude (a "?" inside a string literal counts as well),
	// but does the job for us.
	placeholders := strings.Count(s.query, "?")
	return placeholders
}

// Exec implements the [driver.Stmt] interface. It consumes the first planned
// Exec response whose arguments are deeply equal to args, so every planned
// response is served at most once.
func (s *statement) Exec(args []driver.Value) (driver.Result, error) {
	switch {
	case s.closed:
		return nil, errors.New("statement was closed")
	case s.c.closed:
		return nil, errors.New("connection was closed")
	}

	pending := s.rs.expectedExecs
	pos := slices.IndexFunc(pending, func(e expectation[Result]) bool {
		return reflect.DeepEqual(e.args, args)
	})
	if pos < 0 {
		return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args)
	}

	// Grab the planned output before deleting: slices.Delete shifts the tail down.
	planned := pending[pos].output
	s.rs.expectedExecs = slices.Delete(pending, pos, pos+1)
	return result{r: *planned}, nil
}

// Query implements the [driver.Stmt] interface. It consumes the first planned
// Query response whose arguments are deeply equal to args and returns an
// independent copy of its planned rows.
func (s *statement) Query(args []driver.Value) (driver.Rows, error) {
	switch {
	case s.closed:
		return nil, errors.New("statement was closed")
	case s.c.closed:
		return nil, errors.New("connection was closed")
	}

	pending := s.rs.expectedQueries
	pos := slices.IndexFunc(pending, func(e expectation[Rows]) bool {
		return reflect.DeepEqual(e.args, args)
	})
	if pos < 0 {
		return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args)
	}

	// Grab the planned output before deleting: slices.Delete shifts the tail down.
	planned := pending[pos].output
	s.rs.expectedQueries = slices.Delete(pending, pos, pos+1)
	return &rows{r: *planned}, nil
}

///////////////////////////////////////////////////////////////////////////////////////////
// type Result

// Result is a mock response for an Exec() call.
// It is constructed by [ResponseSet.ExpectExecWithArgs].
//
// A value that is not configured via one of the AndReturn... methods makes the
// corresponding [driver.Result] method return an error.
type Result struct {
	lastInsertId *int64
	rowsAffected *int64
}

// AndReturnLastInsertId configures a mock LastInsertId() value for this Result.
// Returns the same Result instance to allow chaining additional method calls.
func (r *Result) AndReturnLastInsertId(id int64) *Result {
	r.lastInsertId = &id
	return r
}

// AndReturnRowsAffected configures a mock RowsAffected() value for this Result.
// Returns the same Result instance to allow chaining additional method calls.
func (r *Result) AndReturnRowsAffected(count int64) *Result {
	r.rowsAffected = &count
	return r
}

// result is the [driver.Result] handed out by a matched Exec() call. It holds
// a copy of the planned Result taken at the time of the call.
type result struct {
	r Result
}

// LastInsertId implements the [driver.Result] interface.
// It fails if AndReturnLastInsertId() was never called on the planned Result.
func (r result) LastInsertId() (int64, error) {
	if r.r.lastInsertId == nil {
		return 0, errors.New("AndReturnLastInsertId() was not called for this Result")
	}
	return *r.r.lastInsertId, nil
}

// RowsAffected implements the [driver.Result] interface.
// It fails if AndReturnRowsAffected() was never called on the planned Result.
func (r result) RowsAffected() (int64, error) {
	if r.r.rowsAffected == nil {
		return 0, errors.New("AndReturnRowsAffected() was not called for this Result")
	}
	return *r.r.rowsAffected, nil
}

// /////////////////////////////////////////////////////////////////////////////////////////
// type Rows

// Rows is a mock response for a Query() or QueryRow() call.
// It is constructed by [ResponseSet.ExpectQueryWithArgs].
type Rows struct {
	columns []string
	results [][]any
}

// AndReturnColumns configures the set of column names that will be returned by this query.
// The given slice is copied, so later modifications by the caller have no effect.
// Returns the same Rows instance to allow chaining additional method calls.
//
// Panics if columns have already been configured for this Rows object.
func (r *Rows) AndReturnColumns(columns ...string) *Rows {
	if len(r.columns) > 0 {
		panic("AndReturnColumns() called multiple times for the same Rows object")
	}
	r.columns = slices.Clone(columns)
	return r
}

// WithRow adds a row to the result set that will be returned by this query.
// This may only be called after AndReturnColumns(), with exactly one value per column.
// The slice of values is copied (shallowly), so modifying a slice that was passed
// via `values...` afterwards does not change the planned row.
// Returns the same Rows instance to allow chaining additional method calls.
func (r *Rows) WithRow(values ...any) *Rows {
	if len(r.columns) == 0 {
		panic("AndReturnColumns() has not been called for this Rows object yet")
	}
	if len(r.columns) != len(values) {
		panic("WithRow() must be called with the same number of args as the preceding AndReturnColumns() call")
	}
	r.results = append(r.results, slices.Clone(values))
	return r
}

// rows is the [driver.Rows] handed out by a matched Query() call. It works on
// its own copy of the planned Rows; advancing through it only reslices that
// copy and leaves the planned Rows unchanged.
type rows struct {
	r      Rows
	closed bool
}

// Columns implements the [driver.Rows] interface.
func (r *rows) Columns() []string {
	return r.r.columns
}

// Close implements the [driver.Rows] interface.
func (r *rows) Close() error {
	r.closed = true
	return nil
}

// Next implements the [driver.Rows] interface. It copies the next remaining
// row into dest and returns [io.EOF] once all rows have been consumed.
func (r *rows) Next(dest []driver.Value) error {
	if r.closed {
		return errors.New("rows object was closed")
	}
	if len(r.r.results) == 0 {
		return io.EOF
	}
	for idx, value := range r.r.results[0] {
		dest[idx] = value
	}
	r.r.results = r.r.results[1:]
	return nil
}