From d75031ffd1667c330ccc281ea330503eaeaea88a Mon Sep 17 00:00:00 2001 From: Stefan Majewsky Date: Tue, 14 Apr 2026 00:50:20 +0200 Subject: fold package internal into package oblast --- internal/dialect.go | 41 ------- internal/mock/driver.go | 304 ------------------------------------------------ internal/mock/mock.go | 304 ++++++++++++++++++++++++++++++++++++++++++++++++ internal/plan.go | 302 ----------------------------------------------- internal/plan_test.go | 277 ------------------------------------------- 5 files changed, 304 insertions(+), 924 deletions(-) delete mode 100644 internal/dialect.go delete mode 100644 internal/mock/driver.go create mode 100644 internal/mock/mock.go delete mode 100644 internal/plan.go delete mode 100644 internal/plan_test.go (limited to 'internal') diff --git a/internal/dialect.go b/internal/dialect.go deleted file mode 100644 index e6db5b8..0000000 --- a/internal/dialect.go +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Stefan Majewsky -// SPDX-License-Identifier: Apache-2.0 - -package internal - -import ( - "strconv" - "strings" -) - -// Dialect is a copy of the interface of the same name in package oblast. -// We cannot refer to that interface within this package because that would constitute a cyclic dependency. -type Dialect interface { - Placeholder(i int) string - QuoteIdentifier(name string) string - UsesLastInsertID() bool - InsertSuffixForAutoColumns(columns []string) string -} - -// PostgresDialect is the dialect of PostgreSQL databases. -type PostgresDialect struct{} - -func (PostgresDialect) Placeholder(i int) string { return "$" + strconv.Itoa(i+1) } -func (PostgresDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } -func (PostgresDialect) UsesLastInsertID() bool { return false } - -func (p PostgresDialect) InsertSuffixForAutoColumns(columns []string) string { - quotedColumns := make([]string, len(columns)) - for idx, name := range columns { - quotedColumns[idx] = p.QuoteIdentifier(name) - } - return ` RETURNING ` + strings.Join(quotedColumns, ", ") -} - -// SqliteDialect is the dialect of SQLite databases. -type SqliteDialect struct{} - -func (SqliteDialect) Placeholder(_ int) string { return "?" } -func (SqliteDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } -func (SqliteDialect) UsesLastInsertID() bool { return true } -func (SqliteDialect) InsertSuffixForAutoColumns(columns []string) string { return "" } diff --git a/internal/mock/driver.go b/internal/mock/driver.go deleted file mode 100644 index d3358c4..0000000 --- a/internal/mock/driver.go +++ /dev/null @@ -1,304 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Stefan Majewsky -// SPDX-License-Identifier: Apache-2.0 - -package mock - -import ( - "context" - "database/sql/driver" - "errors" - "fmt" - "io" - "reflect" - "slices" - "strings" -) - -//////////////////////////////////////////////////////////////////////////////// -// type Driver - -// Driver is a mock SQL driver that only accepts queries that were preannounced. -type Driver struct { - responseSetsByQuery map[string]*ResponseSet -} - -// assert that interface is implemented -var _ driver.Connector = &Driver{} - -// NewDriver instantiates a new driver. -// The result returns [driver.Connector] and can be given to [sql.OpenDB]. -func NewDriver() *Driver { - return &Driver{ - responseSetsByQuery: make(map[string]*ResponseSet), - } -} - -// Connect implements the [driver.Connector] interface. -func (d *Driver) Connect(ctx context.Context) (driver.Conn, error) { - return &connection{d: d}, nil -} - -// Driver implements the [driver.Connector] interface. -func (d *Driver) Driver() driver.Driver { - // Not needed. Implementing the Driver interface would only be necessary if - // we wanted to use sql.Open() instead of sql.OpenDB(), or if we wanted to - // use sql.DB.Driver(). - panic("unimplemented") -} - -// ForQuery tells the driver to expect the given query string to be sent soon. -// The return value can be used to plan what to return when the query is actually executed. -func (d *Driver) ForQuery(query string) *ResponseSet { - if d.responseSetsByQuery[query] == nil { - d.responseSetsByQuery[query] = &ResponseSet{} - } - return d.responseSetsByQuery[query] -} - -//////////////////////////////////////////////////////////////////////////////// -// type ResponseSet - -// ResponseSet is a set of mock responses for a query sent to type [Driver]. -type ResponseSet struct { - expectedExecs []expectation[Result] - expectedQueries []expectation[Rows] -} - -type expectation[T any] struct { - args []driver.Value - output *T -} - -func newExpectation[T any](args []any) expectation[T] { - e := expectation[T]{ - args: make([]driver.Value, len(args)), - output: new(T), - } - for idx, arg := range args { - var err error - e.args[idx], err = driver.DefaultParameterConverter.ConvertValue(arg) - if err != nil { - panic(fmt.Sprintf("could not convert value %#v into driver.Value: %s", arg, err.Error())) - } - } - return e -} - -// ExpectExecWithArgs plans a response to an Exec() call. -func (rs *ResponseSet) ExpectExecWithArgs(args ...any) *Result { - e := newExpectation[Result](args) - rs.expectedExecs = append(rs.expectedExecs, e) - return e.output -} - -// ExpectQueryWithArgs plans a response to a Query() or QueryRows() call. -func (rs *ResponseSet) ExpectQueryWithArgs(args ...any) *Rows { - e := newExpectation[Rows](args) - rs.expectedQueries = append(rs.expectedQueries, e) - return e.output -} - -//////////////////////////////////////////////////////////////////////////////// -// type connection - -type connection struct { - d *Driver - closed bool -} - -// Prepare implements the [driver.Conn] interface. -func (c *connection) Prepare(query string) (driver.Stmt, error) { - rs := c.d.responseSetsByQuery[query] - if rs == nil { - return nil, fmt.Errorf("unexpected query: %s", query) - } - return &statement{c: c, query: query, rs: rs}, nil -} - -// Close implements the [driver.Conn] interface. -func (c *connection) Close() error { - c.closed = true - return nil -} - -// Begin implements the [driver.Conn] interface. -func (c *connection) Begin() (driver.Tx, error) { - return transaction{}, nil -} - -//////////////////////////////////////////////////////////////////////////////// -// type transaction - -type transaction struct{} - -// Commit implements the [driver.Tx] interface. -func (t transaction) Commit() error { - return nil // unused -} - -// Rollback implements the [driver.Tx] interface. -func (t transaction) Rollback() error { - return nil // unused -} - -//////////////////////////////////////////////////////////////////////////////// -// type statement - -type statement struct { - c *connection - query string - rs *ResponseSet - closed bool -} - -// Close implements the [driver.Stmt] interface. -func (s *statement) Close() error { - return nil -} - -// NumInput implements the [driver.Stmt] interface. -func (s *statement) NumInput() int { - return strings.Count(s.query, "?") // NOTE: extremely crude, but does the job for us -} - -// Exec implements the [driver.Stmt] interface. -func (s *statement) Exec(args []driver.Value) (driver.Result, error) { - if s.closed { - return nil, errors.New("statement was closed") - } - if s.c.closed { - return nil, errors.New("connection was closed") - } - for idx, e := range s.rs.expectedExecs { - if reflect.DeepEqual(e.args, args) { - s.rs.expectedExecs = slices.Delete(s.rs.expectedExecs, idx, idx+1) - return result{r: *e.output}, nil - } - } - return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args) -} - -// Query implements the [driver.Stmt] interface. -func (s *statement) Query(args []driver.Value) (driver.Rows, error) { - if s.closed { - return nil, errors.New("statement was closed") - } - if s.c.closed { - return nil, errors.New("connection was closed") - } - for idx, e := range s.rs.expectedQueries { - if reflect.DeepEqual(e.args, args) { - s.rs.expectedQueries = slices.Delete(s.rs.expectedQueries, idx, idx+1) - return &rows{r: *e.output}, nil - } - } - return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args) -} - -/////////////////////////////////////////////////////////////////////////////////////////// -// type Result - -// Result is a mock response for an Exec() call. -// It is constructed by [ResponseSet.ExpectExec]. -type Result struct { - lastInsertId *int64 - rowsAffected *int64 -} - -// AndReturnLastInsertId configures a mock LastInsertId() value for this Result. -// Returns the same Result instance to allow chaining additional method calls. -func (r *Result) AndReturnLastInsertId(id int64) *Result { - r.lastInsertId = &id - return r -} - -// AndReturnRowsAffected configures a mock RowsAffected() value for this Result. -// Returns the same Result instance to allow chaining additional method calls. -func (r *Result) AndReturnRowsAffected(count int64) *Result { - r.rowsAffected = &count - return r -} - -type result struct { - r Result -} - -// LastInsertId implements the [driver.Result] interface. -func (r result) LastInsertId() (int64, error) { - if r.r.lastInsertId == nil { - return 0, errors.New("AndReturnLastInsertId() was not called for this Result") - } - return *r.r.lastInsertId, nil -} - -// RowsAffected implements the [driver.Result] interface. -func (r result) RowsAffected() (int64, error) { - if r.r.rowsAffected == nil { - return 0, errors.New("AndReturnRowsAffected() was not called for this Result") - } - return *r.r.rowsAffected, nil -} - -// ///////////////////////////////////////////////////////////////////////////////////////// -// type Rows - -// Rows is a mock response for a Query() or QueryRow() call. -// It is constructed by [ResponseSet.ExpectQuery]. -type Rows struct { - columns []string - results [][]any -} - -// AndReturnColumns configures the set of column names that will be returend by this query. -// Returns the same Result instance to allow chaining additional method calls. -func (r *Rows) AndReturnColumns(columns ...string) *Rows { - if len(r.columns) > 0 { - panic("AndReturnColumns() called multiple times for the same Rows object") - } - r.columns = columns - return r -} - -// WithRow adds a row to the result set that will be returned by this query. -// This may only be called after AndReturnColumns(). -func (r *Rows) WithRow(values ...any) *Rows { - if len(r.columns) == 0 { - panic("AndReturnColumns() has not been called for this Rows object yet") - } - if len(r.columns) != len(values) { - panic("WithRow() must be called with the same number of args as the preceding AndReturnColumns() call") - } - r.results = append(r.results, values) - return r -} - -type rows struct { - r Rows - closed bool -} - -// Columns implements the [driver.Rows] interface. -func (r *rows) Columns() []string { - return r.r.columns -} - -// Close implements the [driver.Rows] interface. -func (r *rows) Close() error { - r.closed = true - return nil -} - -// Next implements the [driver.Rows] interface. -func (r *rows) Next(dest []driver.Value) error { - if r.closed { - return errors.New("rows object was closed") - } - if len(r.r.results) == 0 { - return io.EOF - } - for idx, value := range r.r.results[0] { - dest[idx] = value - } - r.r.results = r.r.results[1:] - return nil -} diff --git a/internal/mock/mock.go b/internal/mock/mock.go new file mode 100644 index 0000000..d3358c4 --- /dev/null +++ b/internal/mock/mock.go @@ -0,0 +1,304 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky +// SPDX-License-Identifier: Apache-2.0 + +package mock + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "io" + "reflect" + "slices" + "strings" +) + +//////////////////////////////////////////////////////////////////////////////// +// type Driver + +// Driver is a mock SQL driver that only accepts queries that were preannounced. +type Driver struct { + responseSetsByQuery map[string]*ResponseSet +} + +// assert that interface is implemented +var _ driver.Connector = &Driver{} + +// NewDriver instantiates a new driver. +// The result returns [driver.Connector] and can be given to [sql.OpenDB]. +func NewDriver() *Driver { + return &Driver{ + responseSetsByQuery: make(map[string]*ResponseSet), + } +} + +// Connect implements the [driver.Connector] interface. +func (d *Driver) Connect(ctx context.Context) (driver.Conn, error) { + return &connection{d: d}, nil +} + +// Driver implements the [driver.Connector] interface. +func (d *Driver) Driver() driver.Driver { + // Not needed. Implementing the Driver interface would only be necessary if + // we wanted to use sql.Open() instead of sql.OpenDB(), or if we wanted to + // use sql.DB.Driver(). + panic("unimplemented") +} + +// ForQuery tells the driver to expect the given query string to be sent soon. +// The return value can be used to plan what to return when the query is actually executed. +func (d *Driver) ForQuery(query string) *ResponseSet { + if d.responseSetsByQuery[query] == nil { + d.responseSetsByQuery[query] = &ResponseSet{} + } + return d.responseSetsByQuery[query] +} + +//////////////////////////////////////////////////////////////////////////////// +// type ResponseSet + +// ResponseSet is a set of mock responses for a query sent to type [Driver]. +type ResponseSet struct { + expectedExecs []expectation[Result] + expectedQueries []expectation[Rows] +} + +type expectation[T any] struct { + args []driver.Value + output *T +} + +func newExpectation[T any](args []any) expectation[T] { + e := expectation[T]{ + args: make([]driver.Value, len(args)), + output: new(T), + } + for idx, arg := range args { + var err error + e.args[idx], err = driver.DefaultParameterConverter.ConvertValue(arg) + if err != nil { + panic(fmt.Sprintf("could not convert value %#v into driver.Value: %s", arg, err.Error())) + } + } + return e +} + +// ExpectExecWithArgs plans a response to an Exec() call. +func (rs *ResponseSet) ExpectExecWithArgs(args ...any) *Result { + e := newExpectation[Result](args) + rs.expectedExecs = append(rs.expectedExecs, e) + return e.output +} + +// ExpectQueryWithArgs plans a response to a Query() or QueryRows() call. +func (rs *ResponseSet) ExpectQueryWithArgs(args ...any) *Rows { + e := newExpectation[Rows](args) + rs.expectedQueries = append(rs.expectedQueries, e) + return e.output +} + +//////////////////////////////////////////////////////////////////////////////// +// type connection + +type connection struct { + d *Driver + closed bool +} + +// Prepare implements the [driver.Conn] interface. +func (c *connection) Prepare(query string) (driver.Stmt, error) { + rs := c.d.responseSetsByQuery[query] + if rs == nil { + return nil, fmt.Errorf("unexpected query: %s", query) + } + return &statement{c: c, query: query, rs: rs}, nil +} + +// Close implements the [driver.Conn] interface. +func (c *connection) Close() error { + c.closed = true + return nil +} + +// Begin implements the [driver.Conn] interface. +func (c *connection) Begin() (driver.Tx, error) { + return transaction{}, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// type transaction + +type transaction struct{} + +// Commit implements the [driver.Tx] interface. +func (t transaction) Commit() error { + return nil // unused +} + +// Rollback implements the [driver.Tx] interface. +func (t transaction) Rollback() error { + return nil // unused +} + +//////////////////////////////////////////////////////////////////////////////// +// type statement + +type statement struct { + c *connection + query string + rs *ResponseSet + closed bool +} + +// Close implements the [driver.Stmt] interface. +func (s *statement) Close() error { + return nil +} + +// NumInput implements the [driver.Stmt] interface. +func (s *statement) NumInput() int { + return strings.Count(s.query, "?") // NOTE: extremely crude, but does the job for us +} + +// Exec implements the [driver.Stmt] interface. +func (s *statement) Exec(args []driver.Value) (driver.Result, error) { + if s.closed { + return nil, errors.New("statement was closed") + } + if s.c.closed { + return nil, errors.New("connection was closed") + } + for idx, e := range s.rs.expectedExecs { + if reflect.DeepEqual(e.args, args) { + s.rs.expectedExecs = slices.Delete(s.rs.expectedExecs, idx, idx+1) + return result{r: *e.output}, nil + } + } + return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args) +} + +// Query implements the [driver.Stmt] interface. +func (s *statement) Query(args []driver.Value) (driver.Rows, error) { + if s.closed { + return nil, errors.New("statement was closed") + } + if s.c.closed { + return nil, errors.New("connection was closed") + } + for idx, e := range s.rs.expectedQueries { + if reflect.DeepEqual(e.args, args) { + s.rs.expectedQueries = slices.Delete(s.rs.expectedQueries, idx, idx+1) + return &rows{r: *e.output}, nil + } + } + return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args) +} + +/////////////////////////////////////////////////////////////////////////////////////////// +// type Result + +// Result is a mock response for an Exec() call. +// It is constructed by [ResponseSet.ExpectExec]. +type Result struct { + lastInsertId *int64 + rowsAffected *int64 +} + +// AndReturnLastInsertId configures a mock LastInsertId() value for this Result. +// Returns the same Result instance to allow chaining additional method calls. +func (r *Result) AndReturnLastInsertId(id int64) *Result { + r.lastInsertId = &id + return r +} + +// AndReturnRowsAffected configures a mock RowsAffected() value for this Result. +// Returns the same Result instance to allow chaining additional method calls. +func (r *Result) AndReturnRowsAffected(count int64) *Result { + r.rowsAffected = &count + return r +} + +type result struct { + r Result +} + +// LastInsertId implements the [driver.Result] interface. +func (r result) LastInsertId() (int64, error) { + if r.r.lastInsertId == nil { + return 0, errors.New("AndReturnLastInsertId() was not called for this Result") + } + return *r.r.lastInsertId, nil +} + +// RowsAffected implements the [driver.Result] interface. +func (r result) RowsAffected() (int64, error) { + if r.r.rowsAffected == nil { + return 0, errors.New("AndReturnRowsAffected() was not called for this Result") + } + return *r.r.rowsAffected, nil +} + +// ///////////////////////////////////////////////////////////////////////////////////////// +// type Rows + +// Rows is a mock response for a Query() or QueryRow() call. +// It is constructed by [ResponseSet.ExpectQuery]. +type Rows struct { + columns []string + results [][]any +} + +// AndReturnColumns configures the set of column names that will be returend by this query. +// Returns the same Result instance to allow chaining additional method calls. +func (r *Rows) AndReturnColumns(columns ...string) *Rows { + if len(r.columns) > 0 { + panic("AndReturnColumns() called multiple times for the same Rows object") + } + r.columns = columns + return r +} + +// WithRow adds a row to the result set that will be returned by this query. +// This may only be called after AndReturnColumns(). +func (r *Rows) WithRow(values ...any) *Rows { + if len(r.columns) == 0 { + panic("AndReturnColumns() has not been called for this Rows object yet") + } + if len(r.columns) != len(values) { + panic("WithRow() must be called with the same number of args as the preceding AndReturnColumns() call") + } + r.results = append(r.results, values) + return r +} + +type rows struct { + r Rows + closed bool +} + +// Columns implements the [driver.Rows] interface. +func (r *rows) Columns() []string { + return r.r.columns +} + +// Close implements the [driver.Rows] interface. +func (r *rows) Close() error { + r.closed = true + return nil +} + +// Next implements the [driver.Rows] interface. +func (r *rows) Next(dest []driver.Value) error { + if r.closed { + return errors.New("rows object was closed") + } + if len(r.r.results) == 0 { + return io.EOF + } + for idx, value := range r.r.results[0] { + dest[idx] = value + } + r.r.results = r.r.results[1:] + return nil +} diff --git a/internal/plan.go b/internal/plan.go deleted file mode 100644 index b57b8dd..0000000 --- a/internal/plan.go +++ /dev/null @@ -1,302 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Stefan Majewsky -// SPDX-License-Identifier: Apache-2.0 - -package internal - -import ( - "errors" - "fmt" - "reflect" - "slices" - "strings" -) - -// Plan holds all information that we can derive from reflecting on a given type. -// The queries held within are only valid within the context of a given SQL dialect. -type Plan struct { - TypeName string // for use in error messages - TableName string // from info.TableNameIs marker (if any) - AllColumnNames []string // in order of struct fields - PrimaryKeyColumnNames []string // from info.PrimaryKeyIs marker (if any) - AutoColumnNames []string // subset of AllColumnNames where field has `,auto` marker - - // Argument for reflect.Value.FieldByIndex() for each column name. - IndexByColumnName map[string][]int - - // In dialects with UsesLastInsertID() == true, whether the ID column must be written with reflect.Value.SetInt() or reflect.Value.SetUint(). - FillIDWithSetUint bool - FillIDWithSetInt bool - - // Planned queries. - Select PlannedQuery // only `SELECT ... FROM ... WHERE `; user supplies the rest during Select{,One}Where() - Insert PlannedQuery - Update PlannedQuery - Delete PlannedQuery -} - -// PlannedQuery appears in type Plan. -type PlannedQuery struct { - // Empty if the respective query type is not supported by this Plan for lack of the required marker types. - Query string - // Arguments for reflect.Value.FieldByIndex() in the correct order for the query arguments of the above query. - ArgumentIndexes [][]int - // Arguments for reflect.Value.FieldByIndex() in the correct order for the Scan() arguments of the above query. - ScanIndexes [][]int -} - -// PlanOpts holds additional arguments to BuildPlan(). -type PlanOpts struct { - TableName string - PrimaryKeyColumnNames []string -} - -// BuildPlan creates a new plan for the given struct type. -func BuildPlan(t reflect.Type, dialect Dialect, opts PlanOpts) (Plan, error) { - p, err := buildPlan(t, dialect, opts) - if err != nil { - return Plan{}, fmt.Errorf("cannot use type %s.%s for queries: %w", t.PkgPath(), t.Name(), err) - } - return p, nil -} - -func buildPlan(t reflect.Type, dialect Dialect, opts PlanOpts) (Plan, error) { - if t.Kind() != reflect.Struct { - return Plan{}, fmt.Errorf("expected struct type, but got kind %s", t.Kind().String()) - } - - var p = Plan{ - TypeName: t.Name(), - TableName: opts.TableName, - PrimaryKeyColumnNames: opts.PrimaryKeyColumnNames, - IndexByColumnName: make(map[string][]int), - } - - // discover addressable fields in this type, - // collect information from markers and tags - for _, field := range reflect.VisibleFields(t) { - tags := strings.Split(strings.TrimSpace(field.Tag.Get("db")), ",") - - switch { - case field.PkgPath != "": - // ignore unexported fields (otherwise reflect.Value.Interface() on the field would panic) - continue - case field.Anonymous && field.Type.Kind() == reflect.Struct: - // for embedded struct fields, only consider their members, not the type itself, as a potential column - continue - default: - columnName, extraTags := tags[0], tags[1:] - if columnName == "-" { - continue - } - if columnName == "" { - columnName = field.Name - } - if otherIndex := p.IndexByColumnName[columnName]; otherIndex != nil { - return Plan{}, fmt.Errorf( - "duplicate tag `db:%q` on field index %v, but also on field index %v", - columnName, otherIndex, field.Index, - ) - } - p.IndexByColumnName[columnName] = field.Index - p.AllColumnNames = append(p.AllColumnNames, columnName) - - for _, tag := range extraTags { - switch tag { - case "auto": - p.AutoColumnNames = append(p.AutoColumnNames, columnName) - default: - return Plan{}, fmt.Errorf("unknown tag `db:%q` on field index %v", ","+tag, field.Index) - } - } - } - } - - // validation: defining a primary key only makes sense for records that map onto a single table - if len(p.PrimaryKeyColumnNames) > 0 && p.TableName == "" { - return Plan{}, errors.New("cannot declare a primary key without also providing the TableNameIs option") - } - - // validation: oblast.PrimaryKeyInfo must refer to columns that exist - for _, columnName := range p.PrimaryKeyColumnNames { - _, ok := p.IndexByColumnName[columnName] - if !ok { - return Plan{}, fmt.Errorf("no field has tag `db:%q`, but a field of this name was declared in the primary key", columnName) - } - } - - // validation: LastInsertID() only works if at most one column is auto-filled, and if that column holds an integer type - if dialect.UsesLastInsertID() { - switch len(p.AutoColumnNames) { - case 0: - // nothing to check - case 1: - columnName := p.AutoColumnNames[0] - field := t.FieldByIndex(p.IndexByColumnName[columnName]) - switch field.Type.Kind() { //nolint:exhaustive // false positive - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - p.FillIDWithSetInt = true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - p.FillIDWithSetUint = true - default: - return Plan{}, fmt.Errorf( - "column is marked as auto-filled (%s), but this SQL dialect only supports auto-filling struct fields with integer types", - strings.Join(p.AutoColumnNames, ", "), - ) - } - default: - return Plan{}, fmt.Errorf( - "multiple columns are marked as auto-filled (%s), but this SQL dialect only supports at most one per table", - strings.Join(p.AutoColumnNames, ", "), - ) - } - } - - // prepare query strings - p.Select = p.buildSelectQueryIfPossible(dialect) - p.Insert = p.buildInsertQueryIfPossible(dialect) - p.Update = p.buildUpdateQueryIfPossible(dialect) - p.Delete = p.buildDeleteQueryIfPossible(dialect) - - return p, nil -} - -func (p Plan) getNonAutoColumnNames() []string { - result := make([]string, 0, len(p.AllColumnNames)-len(p.AutoColumnNames)) - for _, columnName := range p.AllColumnNames { - if !slices.Contains(p.AutoColumnNames, columnName) { - result = append(result, columnName) - } - } - return result -} - -func (p Plan) getNonPrimaryKeyColumnNames() []string { - result := make([]string, 0, len(p.AllColumnNames)-len(p.PrimaryKeyColumnNames)) - for _, columnName := range p.AllColumnNames { - if !slices.Contains(p.PrimaryKeyColumnNames, columnName) { - result = append(result, columnName) - } - } - return result -} - -func (p Plan) buildSelectQueryIfPossible(dialect Dialect) PlannedQuery { - if p.TableName == "" { - return PlannedQuery{Query: ""} - } - - var ( - scanIndexes = make([][]int, len(p.AllColumnNames)) - quotedColumnNames = make([]string, len(p.AllColumnNames)) - ) - for idx, columnName := range p.AllColumnNames { - scanIndexes[idx] = p.IndexByColumnName[columnName] - quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName) - } - - query := fmt.Sprintf( - `SELECT %s FROM %s WHERE `, - strings.Join(quotedColumnNames, ", "), - dialect.QuoteIdentifier(p.TableName), - ) - return PlannedQuery{query, nil, scanIndexes} -} - -func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery { - if p.TableName == "" || len(p.AllColumnNames) == 0 { - return PlannedQuery{Query: ""} - } - nonAutoColumnNames := p.getNonAutoColumnNames() - if len(nonAutoColumnNames) == 0 { - return PlannedQuery{Query: ""} - } - - var ( - argumentIndexes = make([][]int, len(nonAutoColumnNames)) - scanIndexes [][]int - quotedColumnNames = make([]string, len(nonAutoColumnNames)) - quotedPlaceholders = make([]string, len(nonAutoColumnNames)) - ) - for idx, columnName := range nonAutoColumnNames { - argumentIndexes[idx] = p.IndexByColumnName[columnName] - quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName) - quotedPlaceholders[idx] = dialect.Placeholder(idx) - } - if len(p.AutoColumnNames) > 0 { - // NOTE: This is filled even if dialect.UsesLastInsertID() is false. - // We need this index to find the right value on which to run SetInt() or SetUint(). - scanIndexes = make([][]int, len(p.AutoColumnNames)) - for idx, columnName := range p.AutoColumnNames { - scanIndexes[idx] = p.IndexByColumnName[columnName] - } - } - - query := fmt.Sprintf( - `INSERT INTO %s (%s) VALUES (%s)`, - dialect.QuoteIdentifier(p.TableName), - strings.Join(quotedColumnNames, ", "), - strings.Join(quotedPlaceholders, ", "), - ) - if len(p.AutoColumnNames) > 0 { - query += dialect.InsertSuffixForAutoColumns(p.AutoColumnNames) - } - return PlannedQuery{query, argumentIndexes, scanIndexes} -} - -func (p Plan) buildUpdateQueryIfPossible(dialect Dialect) PlannedQuery { - if p.TableName == "" || len(p.PrimaryKeyColumnNames) == 0 { - return PlannedQuery{Query: ""} - } - nonPrimaryKeyColumnNames := p.getNonPrimaryKeyColumnNames() - if len(nonPrimaryKeyColumnNames) == 0 { - return PlannedQuery{Query: ""} - } - - var ( - setArgumentIndexes = make([][]int, len(nonPrimaryKeyColumnNames)) - setClauses = make([]string, len(nonPrimaryKeyColumnNames)) - ) - for idx, columnName := range nonPrimaryKeyColumnNames { - setArgumentIndexes[idx] = p.IndexByColumnName[columnName] - setClauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx)) - } - - var ( - whereArgumentIndexes = make([][]int, len(p.PrimaryKeyColumnNames)) - whereClauses = make([]string, len(p.PrimaryKeyColumnNames)) - ) - for idx, columnName := range p.PrimaryKeyColumnNames { - whereArgumentIndexes[idx] = p.IndexByColumnName[columnName] - whereClauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx+len(setClauses))) - } - - query := fmt.Sprintf( - `UPDATE %s SET %s WHERE %s`, - dialect.QuoteIdentifier(p.TableName), - strings.Join(setClauses, ", "), - strings.Join(whereClauses, " AND "), - ) - return PlannedQuery{query, slices.Concat(setArgumentIndexes, whereArgumentIndexes), nil} -} - -func (p Plan) buildDeleteQueryIfPossible(dialect Dialect) PlannedQuery { - if p.TableName == "" || len(p.PrimaryKeyColumnNames) == 0 { - return PlannedQuery{Query: ""} - } - - var ( - argumentIndexes = make([][]int, len(p.PrimaryKeyColumnNames)) - clauses = make([]string, len(p.PrimaryKeyColumnNames)) - ) - for idx, columnName := range p.PrimaryKeyColumnNames { - argumentIndexes[idx] = p.IndexByColumnName[columnName] - clauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx)) - } - - query := fmt.Sprintf( - `DELETE FROM %s WHERE %s`, - dialect.QuoteIdentifier(p.TableName), - strings.Join(clauses, " AND "), - ) - return PlannedQuery{query, argumentIndexes, nil} -} diff --git a/internal/plan_test.go b/internal/plan_test.go deleted file mode 100644 index e692556..0000000 --- a/internal/plan_test.go +++ /dev/null @@ -1,277 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Stefan Majewsky -// SPDX-License-Identifier: Apache-2.0 - -package internal_test - -import ( - "reflect" - "testing" - "time" - - "go.xyrillian.de/oblast/internal" - "go.xyrillian.de/oblast/internal/assert" -) - -func TestPlanFieldTraversal(t *testing.T) { - type Timestamps struct { - CreatedAt time.Time `db:"created_at"` - UpdatedAt *time.Time `db:"updated_at"` - } - type yetMoreTimestamps struct { - DeletedAt *time.Time `db:"deleted_at"` - } - type Log struct { - ID int64 `db:"id,auto"` - Message string - private1 bool `db:"private1"` //nolint:unused - Ignored any `db:"-"` - Timestamps - yetMoreTimestamps - } - - // check that the plan for Log: - // 1. has no IndexByColumnName entries for marker types - // 2. uses the field name as a column name for "Message" - // 3. ignores "private1" because it cannot be written through reflection - // 4. ignores "Ignored" because its column name is "-" - // 5. traverses into "Timestamps" and includes its fields as well - // 6. traverses into "yetMoreTimestamps" as well (despite the extra pointer and the type being private) - // 7. recognizes "id" as an autofilled column - plan, err := internal.BuildPlan(reflect.TypeFor[Log](), internal.PostgresDialect{}, internal.PlanOpts{ - TableName: "log_entries", - PrimaryKeyColumnNames: []string{"id"}, - }) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.TableName, "log_entries") - assert.DeepEqual(t, plan.AllColumnNames, []string{"id", "Message", "created_at", "updated_at", "deleted_at"}) - assert.DeepEqual(t, plan.PrimaryKeyColumnNames, []string{"id"}) - assert.DeepEqual(t, plan.AutoColumnNames, []string{"id"}) - assert.DeepEqual(t, plan.IndexByColumnName, map[string][]int{ - "id": {0}, - "Message": {1}, - "created_at": {4, 0}, - "updated_at": {4, 1}, - "deleted_at": {5, 0}, - }) -} - -// TODO: test that, during Select(), assignment into embedded fields with pointer-to-struct type works (docs say that this might panic if we do not allocate into the pointer first) - -func TestQueryConstructionBasic(t *testing.T) { - type record struct { - ID int64 `db:",auto"` - Description string - CreatedAt time.Time - } - opts := internal.PlanOpts{ - TableName: "basic_records", - PrimaryKeyColumnNames: []string{"ID"}, - } - - t.Run("PostgresDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[record](), internal.PostgresDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES ($1, $2) RETURNING "ID"`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) - assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = $1, "CreatedAt" = $2 WHERE "ID" = $3`) - assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, `DELETE FROM "basic_records" WHERE "ID" = $1`) - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) - - t.Run("SqliteDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[record](), internal.SqliteDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?)`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) - assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = ?, "CreatedAt" = ? WHERE "ID" = ?`) - assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, `DELETE FROM "basic_records" WHERE "ID" = ?`) - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) -} - -func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { - type relation struct { - FooID int64 `db:"foo_id"` - BarID int64 `db:"bar_id"` - } - opts := internal.PlanOpts{ - TableName: "foo_bar_relations", - } - - t.Run("PostgresDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[relation](), internal.PostgresDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2)`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) - assert.Equal(t, plan.Update.Query, "") - assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, "") - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) - - t.Run("SqliteDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[relation](), internal.SqliteDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) - assert.Equal(t, plan.Update.Query, "") - assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, "") - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) -} - -func TestQueryConstructionImpossble(t *testing.T) { - type unstructuredData struct { - Foo int - Bar string - } - opts := internal.PlanOpts{} - - testWith := func(dialect internal.Dialect) func(*testing.T) { - return func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[unstructuredData](), dialect, opts) - if err != nil { - t.Error(err) - } - - assert.Equal(t, plan.Select.Query, "") - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, nil) - assert.Equal(t, plan.Insert.Query, "") - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) - assert.Equal(t, plan.Update.Query, "") - assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, "") - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - } - } - - t.Run("PostgresDialect", testWith(internal.PostgresDialect{})) - t.Run("SqliteDialect", testWith(internal.SqliteDialect{})) -} - -func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { - type record struct { - GroupID int64 `db:"group_id"` - Name string `db:"name"` - CreatedAt time.Time `db:"created_at"` - } - opts := internal.PlanOpts{ - TableName: "complex_records", - PrimaryKeyColumnNames: []string{"group_id", "name"}, - } - - t.Run("PostgresDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[record](), internal.PostgresDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES ($1, $2, $3)`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) - assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = $1 WHERE "group_id" = $2 AND "name" = $3`) - assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}}) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, `DELETE FROM "complex_records" WHERE "group_id" = $1 AND "name" = $2`) - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}, {1}}) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) - - t.Run("SqliteDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[record](), internal.SqliteDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES (?, ?, ?)`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) - assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = ? WHERE "group_id" = ? AND "name" = ?`) - assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}}) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, `DELETE FROM "complex_records" WHERE "group_id" = ? AND "name" = ?`) - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}, {1}}) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) -} - -func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) { - type record struct { - ID int64 `db:"id,auto"` - Name string `db:"name"` - CreatedAt time.Time `db:"created_at,auto"` - } - opts := internal.PlanOpts{ - TableName: "autogenerated_records", - PrimaryKeyColumnNames: []string{"id"}, - } - - t.Run("PostgresDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[record](), internal.PostgresDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "id", "name", "created_at" FROM "autogenerated_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "autogenerated_records" ("name") VALUES ($1) RETURNING "id", "created_at"`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}, {2}}) - assert.Equal(t, plan.Update.Query, `UPDATE "autogenerated_records" SET "name" = $1, "created_at" = $2 WHERE "id" = $3`) - assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, `DELETE FROM "autogenerated_records" WHERE "id" = $1`) - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) - - t.Run("SqliteDialect", func(t *testing.T) { - _, err := internal.BuildPlan(reflect.TypeFor[record](), internal.SqliteDialect{}, opts) - assert.Equal(t, err.Error(), `cannot use type go.xyrillian.de/oblast/internal_test.record for queries: multiple columns are marked as auto-filled (id, created_at), but this SQL dialect only supports at most one per table`) - }) -} -- cgit v1.2.3