From d75031ffd1667c330ccc281ea330503eaeaea88a Mon Sep 17 00:00:00 2001 From: Stefan Majewsky Date: Tue, 14 Apr 2026 00:50:20 +0200 Subject: fold package internal into package oblast --- dialect.go | 30 ++++- internal/dialect.go | 41 ------- internal/mock/driver.go | 304 ------------------------------------------------ internal/mock/mock.go | 304 ++++++++++++++++++++++++++++++++++++++++++++++++ internal/plan.go | 302 ----------------------------------------------- internal/plan_test.go | 277 ------------------------------------------- oblast.go | 19 +-- plan.go | 294 ++++++++++++++++++++++++++++++++++++++++++++++ plan_test.go | 278 +++++++++++++++++++++++++++++++++++++++++++ select.go | 8 +- 10 files changed, 917 insertions(+), 940 deletions(-) delete mode 100644 internal/dialect.go delete mode 100644 internal/mock/driver.go create mode 100644 internal/mock/mock.go delete mode 100644 internal/plan.go delete mode 100644 internal/plan_test.go create mode 100644 plan.go create mode 100644 plan_test.go diff --git a/dialect.go b/dialect.go index acbb160..a2827e2 100644 --- a/dialect.go +++ b/dialect.go @@ -3,7 +3,10 @@ package oblast -import "go.xyrillian.de/oblast/internal" +import ( + "strconv" + "strings" +) // Dialect accounts for differences between different SQL dialects // that are relevant to query generation within Oblast. @@ -38,10 +41,31 @@ type Dialect interface { // PostgresDialect is the dialect of PostgreSQL databases. func PostgresDialect() Dialect { - return internal.PostgresDialect{} + return postgresDialect{} +} + +type postgresDialect struct{} + +func (postgresDialect) Placeholder(i int) string { return "$" + strconv.Itoa(i+1) } +func (postgresDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } +func (postgresDialect) UsesLastInsertID() bool { return false } + +func (p postgresDialect) InsertSuffixForAutoColumns(columns []string) string { + quotedColumns := make([]string, len(columns)) + for idx, name := range columns { + quotedColumns[idx] = p.QuoteIdentifier(name) + } + return ` RETURNING ` + strings.Join(quotedColumns, ", ") } // SqliteDialect is the dialect of SQLite databases. func SqliteDialect() Dialect { - return internal.SqliteDialect{} + return sqliteDialect{} } + +type sqliteDialect struct{} + +func (sqliteDialect) Placeholder(_ int) string { return "?" } +func (sqliteDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } +func (sqliteDialect) UsesLastInsertID() bool { return true } +func (sqliteDialect) InsertSuffixForAutoColumns(columns []string) string { return "" } diff --git a/internal/dialect.go b/internal/dialect.go deleted file mode 100644 index e6db5b8..0000000 --- a/internal/dialect.go +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Stefan Majewsky -// SPDX-License-Identifier: Apache-2.0 - -package internal - -import ( - "strconv" - "strings" -) - -// Dialect is a copy of the interface of the same name in package oblast. -// We cannot refer to that interface within this package because that would constitute a cyclic dependency. -type Dialect interface { - Placeholder(i int) string - QuoteIdentifier(name string) string - UsesLastInsertID() bool - InsertSuffixForAutoColumns(columns []string) string -} - -// PostgresDialect is the dialect of PostgreSQL databases. -type PostgresDialect struct{} - -func (PostgresDialect) Placeholder(i int) string { return "$" + strconv.Itoa(i+1) } -func (PostgresDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } -func (PostgresDialect) UsesLastInsertID() bool { return false } - -func (p PostgresDialect) InsertSuffixForAutoColumns(columns []string) string { - quotedColumns := make([]string, len(columns)) - for idx, name := range columns { - quotedColumns[idx] = p.QuoteIdentifier(name) - } - return ` RETURNING ` + strings.Join(quotedColumns, ", ") -} - -// SqliteDialect is the dialect of SQLite databases. -type SqliteDialect struct{} - -func (SqliteDialect) Placeholder(_ int) string { return "?" } -func (SqliteDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } -func (SqliteDialect) UsesLastInsertID() bool { return true } -func (SqliteDialect) InsertSuffixForAutoColumns(columns []string) string { return "" } diff --git a/internal/mock/driver.go b/internal/mock/driver.go deleted file mode 100644 index d3358c4..0000000 --- a/internal/mock/driver.go +++ /dev/null @@ -1,304 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Stefan Majewsky -// SPDX-License-Identifier: Apache-2.0 - -package mock - -import ( - "context" - "database/sql/driver" - "errors" - "fmt" - "io" - "reflect" - "slices" - "strings" -) - -//////////////////////////////////////////////////////////////////////////////// -// type Driver - -// Driver is a mock SQL driver that only accepts queries that were preannounced. -type Driver struct { - responseSetsByQuery map[string]*ResponseSet -} - -// assert that interface is implemented -var _ driver.Connector = &Driver{} - -// NewDriver instantiates a new driver. -// The result returns [driver.Connector] and can be given to [sql.OpenDB]. -func NewDriver() *Driver { - return &Driver{ - responseSetsByQuery: make(map[string]*ResponseSet), - } -} - -// Connect implements the [driver.Connector] interface. -func (d *Driver) Connect(ctx context.Context) (driver.Conn, error) { - return &connection{d: d}, nil -} - -// Driver implements the [driver.Connector] interface. -func (d *Driver) Driver() driver.Driver { - // Not needed. Implementing the Driver interface would only be necessary if - // we wanted to use sql.Open() instead of sql.OpenDB(), or if we wanted to - // use sql.DB.Driver(). - panic("unimplemented") -} - -// ForQuery tells the driver to expect the given query string to be sent soon. -// The return value can be used to plan what to return when the query is actually executed. -func (d *Driver) ForQuery(query string) *ResponseSet { - if d.responseSetsByQuery[query] == nil { - d.responseSetsByQuery[query] = &ResponseSet{} - } - return d.responseSetsByQuery[query] -} - -//////////////////////////////////////////////////////////////////////////////// -// type ResponseSet - -// ResponseSet is a set of mock responses for a query sent to type [Driver]. -type ResponseSet struct { - expectedExecs []expectation[Result] - expectedQueries []expectation[Rows] -} - -type expectation[T any] struct { - args []driver.Value - output *T -} - -func newExpectation[T any](args []any) expectation[T] { - e := expectation[T]{ - args: make([]driver.Value, len(args)), - output: new(T), - } - for idx, arg := range args { - var err error - e.args[idx], err = driver.DefaultParameterConverter.ConvertValue(arg) - if err != nil { - panic(fmt.Sprintf("could not convert value %#v into driver.Value: %s", arg, err.Error())) - } - } - return e -} - -// ExpectExecWithArgs plans a response to an Exec() call. -func (rs *ResponseSet) ExpectExecWithArgs(args ...any) *Result { - e := newExpectation[Result](args) - rs.expectedExecs = append(rs.expectedExecs, e) - return e.output -} - -// ExpectQueryWithArgs plans a response to a Query() or QueryRows() call. -func (rs *ResponseSet) ExpectQueryWithArgs(args ...any) *Rows { - e := newExpectation[Rows](args) - rs.expectedQueries = append(rs.expectedQueries, e) - return e.output -} - -//////////////////////////////////////////////////////////////////////////////// -// type connection - -type connection struct { - d *Driver - closed bool -} - -// Prepare implements the [driver.Conn] interface. -func (c *connection) Prepare(query string) (driver.Stmt, error) { - rs := c.d.responseSetsByQuery[query] - if rs == nil { - return nil, fmt.Errorf("unexpected query: %s", query) - } - return &statement{c: c, query: query, rs: rs}, nil -} - -// Close implements the [driver.Conn] interface. -func (c *connection) Close() error { - c.closed = true - return nil -} - -// Begin implements the [driver.Conn] interface. -func (c *connection) Begin() (driver.Tx, error) { - return transaction{}, nil -} - -//////////////////////////////////////////////////////////////////////////////// -// type transaction - -type transaction struct{} - -// Commit implements the [driver.Tx] interface. -func (t transaction) Commit() error { - return nil // unused -} - -// Rollback implements the [driver.Tx] interface. -func (t transaction) Rollback() error { - return nil // unused -} - -//////////////////////////////////////////////////////////////////////////////// -// type statement - -type statement struct { - c *connection - query string - rs *ResponseSet - closed bool -} - -// Close implements the [driver.Stmt] interface. -func (s *statement) Close() error { - return nil -} - -// NumInput implements the [driver.Stmt] interface. -func (s *statement) NumInput() int { - return strings.Count(s.query, "?") // NOTE: extremely crude, but does the job for us -} - -// Exec implements the [driver.Stmt] interface. -func (s *statement) Exec(args []driver.Value) (driver.Result, error) { - if s.closed { - return nil, errors.New("statement was closed") - } - if s.c.closed { - return nil, errors.New("connection was closed") - } - for idx, e := range s.rs.expectedExecs { - if reflect.DeepEqual(e.args, args) { - s.rs.expectedExecs = slices.Delete(s.rs.expectedExecs, idx, idx+1) - return result{r: *e.output}, nil - } - } - return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args) -} - -// Query implements the [driver.Stmt] interface. -func (s *statement) Query(args []driver.Value) (driver.Rows, error) { - if s.closed { - return nil, errors.New("statement was closed") - } - if s.c.closed { - return nil, errors.New("connection was closed") - } - for idx, e := range s.rs.expectedQueries { - if reflect.DeepEqual(e.args, args) { - s.rs.expectedQueries = slices.Delete(s.rs.expectedQueries, idx, idx+1) - return &rows{r: *e.output}, nil - } - } - return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args) -} - -/////////////////////////////////////////////////////////////////////////////////////////// -// type Result - -// Result is a mock response for an Exec() call. -// It is constructed by [ResponseSet.ExpectExec]. -type Result struct { - lastInsertId *int64 - rowsAffected *int64 -} - -// AndReturnLastInsertId configures a mock LastInsertId() value for this Result. -// Returns the same Result instance to allow chaining additional method calls. -func (r *Result) AndReturnLastInsertId(id int64) *Result { - r.lastInsertId = &id - return r -} - -// AndReturnRowsAffected configures a mock RowsAffected() value for this Result. -// Returns the same Result instance to allow chaining additional method calls. -func (r *Result) AndReturnRowsAffected(count int64) *Result { - r.rowsAffected = &count - return r -} - -type result struct { - r Result -} - -// LastInsertId implements the [driver.Result] interface. -func (r result) LastInsertId() (int64, error) { - if r.r.lastInsertId == nil { - return 0, errors.New("AndReturnLastInsertId() was not called for this Result") - } - return *r.r.lastInsertId, nil -} - -// RowsAffected implements the [driver.Result] interface. -func (r result) RowsAffected() (int64, error) { - if r.r.rowsAffected == nil { - return 0, errors.New("AndReturnRowsAffected() was not called for this Result") - } - return *r.r.rowsAffected, nil -} - -// ///////////////////////////////////////////////////////////////////////////////////////// -// type Rows - -// Rows is a mock response for a Query() or QueryRow() call. -// It is constructed by [ResponseSet.ExpectQuery]. -type Rows struct { - columns []string - results [][]any -} - -// AndReturnColumns configures the set of column names that will be returend by this query. -// Returns the same Result instance to allow chaining additional method calls. -func (r *Rows) AndReturnColumns(columns ...string) *Rows { - if len(r.columns) > 0 { - panic("AndReturnColumns() called multiple times for the same Rows object") - } - r.columns = columns - return r -} - -// WithRow adds a row to the result set that will be returned by this query. -// This may only be called after AndReturnColumns(). -func (r *Rows) WithRow(values ...any) *Rows { - if len(r.columns) == 0 { - panic("AndReturnColumns() has not been called for this Rows object yet") - } - if len(r.columns) != len(values) { - panic("WithRow() must be called with the same number of args as the preceding AndReturnColumns() call") - } - r.results = append(r.results, values) - return r -} - -type rows struct { - r Rows - closed bool -} - -// Columns implements the [driver.Rows] interface. -func (r *rows) Columns() []string { - return r.r.columns -} - -// Close implements the [driver.Rows] interface. -func (r *rows) Close() error { - r.closed = true - return nil -} - -// Next implements the [driver.Rows] interface. -func (r *rows) Next(dest []driver.Value) error { - if r.closed { - return errors.New("rows object was closed") - } - if len(r.r.results) == 0 { - return io.EOF - } - for idx, value := range r.r.results[0] { - dest[idx] = value - } - r.r.results = r.r.results[1:] - return nil -} diff --git a/internal/mock/mock.go b/internal/mock/mock.go new file mode 100644 index 0000000..d3358c4 --- /dev/null +++ b/internal/mock/mock.go @@ -0,0 +1,304 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky +// SPDX-License-Identifier: Apache-2.0 + +package mock + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "io" + "reflect" + "slices" + "strings" +) + +//////////////////////////////////////////////////////////////////////////////// +// type Driver + +// Driver is a mock SQL driver that only accepts queries that were preannounced. +type Driver struct { + responseSetsByQuery map[string]*ResponseSet +} + +// assert that interface is implemented +var _ driver.Connector = &Driver{} + +// NewDriver instantiates a new driver. +// The result returns [driver.Connector] and can be given to [sql.OpenDB]. +func NewDriver() *Driver { + return &Driver{ + responseSetsByQuery: make(map[string]*ResponseSet), + } +} + +// Connect implements the [driver.Connector] interface. +func (d *Driver) Connect(ctx context.Context) (driver.Conn, error) { + return &connection{d: d}, nil +} + +// Driver implements the [driver.Connector] interface. +func (d *Driver) Driver() driver.Driver { + // Not needed. Implementing the Driver interface would only be necessary if + // we wanted to use sql.Open() instead of sql.OpenDB(), or if we wanted to + // use sql.DB.Driver(). + panic("unimplemented") +} + +// ForQuery tells the driver to expect the given query string to be sent soon. +// The return value can be used to plan what to return when the query is actually executed. +func (d *Driver) ForQuery(query string) *ResponseSet { + if d.responseSetsByQuery[query] == nil { + d.responseSetsByQuery[query] = &ResponseSet{} + } + return d.responseSetsByQuery[query] +} + +//////////////////////////////////////////////////////////////////////////////// +// type ResponseSet + +// ResponseSet is a set of mock responses for a query sent to type [Driver]. +type ResponseSet struct { + expectedExecs []expectation[Result] + expectedQueries []expectation[Rows] +} + +type expectation[T any] struct { + args []driver.Value + output *T +} + +func newExpectation[T any](args []any) expectation[T] { + e := expectation[T]{ + args: make([]driver.Value, len(args)), + output: new(T), + } + for idx, arg := range args { + var err error + e.args[idx], err = driver.DefaultParameterConverter.ConvertValue(arg) + if err != nil { + panic(fmt.Sprintf("could not convert value %#v into driver.Value: %s", arg, err.Error())) + } + } + return e +} + +// ExpectExecWithArgs plans a response to an Exec() call. +func (rs *ResponseSet) ExpectExecWithArgs(args ...any) *Result { + e := newExpectation[Result](args) + rs.expectedExecs = append(rs.expectedExecs, e) + return e.output +} + +// ExpectQueryWithArgs plans a response to a Query() or QueryRows() call. +func (rs *ResponseSet) ExpectQueryWithArgs(args ...any) *Rows { + e := newExpectation[Rows](args) + rs.expectedQueries = append(rs.expectedQueries, e) + return e.output +} + +//////////////////////////////////////////////////////////////////////////////// +// type connection + +type connection struct { + d *Driver + closed bool +} + +// Prepare implements the [driver.Conn] interface. +func (c *connection) Prepare(query string) (driver.Stmt, error) { + rs := c.d.responseSetsByQuery[query] + if rs == nil { + return nil, fmt.Errorf("unexpected query: %s", query) + } + return &statement{c: c, query: query, rs: rs}, nil +} + +// Close implements the [driver.Conn] interface. +func (c *connection) Close() error { + c.closed = true + return nil +} + +// Begin implements the [driver.Conn] interface. +func (c *connection) Begin() (driver.Tx, error) { + return transaction{}, nil +} + +//////////////////////////////////////////////////////////////////////////////// +// type transaction + +type transaction struct{} + +// Commit implements the [driver.Tx] interface. +func (t transaction) Commit() error { + return nil // unused +} + +// Rollback implements the [driver.Tx] interface. +func (t transaction) Rollback() error { + return nil // unused +} + +//////////////////////////////////////////////////////////////////////////////// +// type statement + +type statement struct { + c *connection + query string + rs *ResponseSet + closed bool +} + +// Close implements the [driver.Stmt] interface. +func (s *statement) Close() error { + return nil +} + +// NumInput implements the [driver.Stmt] interface. +func (s *statement) NumInput() int { + return strings.Count(s.query, "?") // NOTE: extremely crude, but does the job for us +} + +// Exec implements the [driver.Stmt] interface. +func (s *statement) Exec(args []driver.Value) (driver.Result, error) { + if s.closed { + return nil, errors.New("statement was closed") + } + if s.c.closed { + return nil, errors.New("connection was closed") + } + for idx, e := range s.rs.expectedExecs { + if reflect.DeepEqual(e.args, args) { + s.rs.expectedExecs = slices.Delete(s.rs.expectedExecs, idx, idx+1) + return result{r: *e.output}, nil + } + } + return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args) +} + +// Query implements the [driver.Stmt] interface. +func (s *statement) Query(args []driver.Value) (driver.Rows, error) { + if s.closed { + return nil, errors.New("statement was closed") + } + if s.c.closed { + return nil, errors.New("connection was closed") + } + for idx, e := range s.rs.expectedQueries { + if reflect.DeepEqual(e.args, args) { + s.rs.expectedQueries = slices.Delete(s.rs.expectedQueries, idx, idx+1) + return &rows{r: *e.output}, nil + } + } + return nil, fmt.Errorf("unexpected arguments for query %q: %#v", s.query, args) +} + +/////////////////////////////////////////////////////////////////////////////////////////// +// type Result + +// Result is a mock response for an Exec() call. +// It is constructed by [ResponseSet.ExpectExec]. +type Result struct { + lastInsertId *int64 + rowsAffected *int64 +} + +// AndReturnLastInsertId configures a mock LastInsertId() value for this Result. +// Returns the same Result instance to allow chaining additional method calls. +func (r *Result) AndReturnLastInsertId(id int64) *Result { + r.lastInsertId = &id + return r +} + +// AndReturnRowsAffected configures a mock RowsAffected() value for this Result. +// Returns the same Result instance to allow chaining additional method calls. +func (r *Result) AndReturnRowsAffected(count int64) *Result { + r.rowsAffected = &count + return r +} + +type result struct { + r Result +} + +// LastInsertId implements the [driver.Result] interface. +func (r result) LastInsertId() (int64, error) { + if r.r.lastInsertId == nil { + return 0, errors.New("AndReturnLastInsertId() was not called for this Result") + } + return *r.r.lastInsertId, nil +} + +// RowsAffected implements the [driver.Result] interface. +func (r result) RowsAffected() (int64, error) { + if r.r.rowsAffected == nil { + return 0, errors.New("AndReturnRowsAffected() was not called for this Result") + } + return *r.r.rowsAffected, nil +} + +// ///////////////////////////////////////////////////////////////////////////////////////// +// type Rows + +// Rows is a mock response for a Query() or QueryRow() call. +// It is constructed by [ResponseSet.ExpectQuery]. +type Rows struct { + columns []string + results [][]any +} + +// AndReturnColumns configures the set of column names that will be returend by this query. +// Returns the same Result instance to allow chaining additional method calls. +func (r *Rows) AndReturnColumns(columns ...string) *Rows { + if len(r.columns) > 0 { + panic("AndReturnColumns() called multiple times for the same Rows object") + } + r.columns = columns + return r +} + +// WithRow adds a row to the result set that will be returned by this query. +// This may only be called after AndReturnColumns(). +func (r *Rows) WithRow(values ...any) *Rows { + if len(r.columns) == 0 { + panic("AndReturnColumns() has not been called for this Rows object yet") + } + if len(r.columns) != len(values) { + panic("WithRow() must be called with the same number of args as the preceding AndReturnColumns() call") + } + r.results = append(r.results, values) + return r +} + +type rows struct { + r Rows + closed bool +} + +// Columns implements the [driver.Rows] interface. +func (r *rows) Columns() []string { + return r.r.columns +} + +// Close implements the [driver.Rows] interface. +func (r *rows) Close() error { + r.closed = true + return nil +} + +// Next implements the [driver.Rows] interface. +func (r *rows) Next(dest []driver.Value) error { + if r.closed { + return errors.New("rows object was closed") + } + if len(r.r.results) == 0 { + return io.EOF + } + for idx, value := range r.r.results[0] { + dest[idx] = value + } + r.r.results = r.r.results[1:] + return nil +} diff --git a/internal/plan.go b/internal/plan.go deleted file mode 100644 index b57b8dd..0000000 --- a/internal/plan.go +++ /dev/null @@ -1,302 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Stefan Majewsky -// SPDX-License-Identifier: Apache-2.0 - -package internal - -import ( - "errors" - "fmt" - "reflect" - "slices" - "strings" -) - -// Plan holds all information that we can derive from reflecting on a given type. -// The queries held within are only valid within the context of a given SQL dialect. -type Plan struct { - TypeName string // for use in error messages - TableName string // from info.TableNameIs marker (if any) - AllColumnNames []string // in order of struct fields - PrimaryKeyColumnNames []string // from info.PrimaryKeyIs marker (if any) - AutoColumnNames []string // subset of AllColumnNames where field has `,auto` marker - - // Argument for reflect.Value.FieldByIndex() for each column name. - IndexByColumnName map[string][]int - - // In dialects with UsesLastInsertID() == true, whether the ID column must be written with reflect.Value.SetInt() or reflect.Value.SetUint(). - FillIDWithSetUint bool - FillIDWithSetInt bool - - // Planned queries. - Select PlannedQuery // only `SELECT ... FROM ... WHERE `; user supplies the rest during Select{,One}Where() - Insert PlannedQuery - Update PlannedQuery - Delete PlannedQuery -} - -// PlannedQuery appears in type Plan. -type PlannedQuery struct { - // Empty if the respective query type is not supported by this Plan for lack of the required marker types. - Query string - // Arguments for reflect.Value.FieldByIndex() in the correct order for the query arguments of the above query. - ArgumentIndexes [][]int - // Arguments for reflect.Value.FieldByIndex() in the correct order for the Scan() arguments of the above query. - ScanIndexes [][]int -} - -// PlanOpts holds additional arguments to BuildPlan(). -type PlanOpts struct { - TableName string - PrimaryKeyColumnNames []string -} - -// BuildPlan creates a new plan for the given struct type. -func BuildPlan(t reflect.Type, dialect Dialect, opts PlanOpts) (Plan, error) { - p, err := buildPlan(t, dialect, opts) - if err != nil { - return Plan{}, fmt.Errorf("cannot use type %s.%s for queries: %w", t.PkgPath(), t.Name(), err) - } - return p, nil -} - -func buildPlan(t reflect.Type, dialect Dialect, opts PlanOpts) (Plan, error) { - if t.Kind() != reflect.Struct { - return Plan{}, fmt.Errorf("expected struct type, but got kind %s", t.Kind().String()) - } - - var p = Plan{ - TypeName: t.Name(), - TableName: opts.TableName, - PrimaryKeyColumnNames: opts.PrimaryKeyColumnNames, - IndexByColumnName: make(map[string][]int), - } - - // discover addressable fields in this type, - // collect information from markers and tags - for _, field := range reflect.VisibleFields(t) { - tags := strings.Split(strings.TrimSpace(field.Tag.Get("db")), ",") - - switch { - case field.PkgPath != "": - // ignore unexported fields (otherwise reflect.Value.Interface() on the field would panic) - continue - case field.Anonymous && field.Type.Kind() == reflect.Struct: - // for embedded struct fields, only consider their members, not the type itself, as a potential column - continue - default: - columnName, extraTags := tags[0], tags[1:] - if columnName == "-" { - continue - } - if columnName == "" { - columnName = field.Name - } - if otherIndex := p.IndexByColumnName[columnName]; otherIndex != nil { - return Plan{}, fmt.Errorf( - "duplicate tag `db:%q` on field index %v, but also on field index %v", - columnName, otherIndex, field.Index, - ) - } - p.IndexByColumnName[columnName] = field.Index - p.AllColumnNames = append(p.AllColumnNames, columnName) - - for _, tag := range extraTags { - switch tag { - case "auto": - p.AutoColumnNames = append(p.AutoColumnNames, columnName) - default: - return Plan{}, fmt.Errorf("unknown tag `db:%q` on field index %v", ","+tag, field.Index) - } - } - } - } - - // validation: defining a primary key only makes sense for records that map onto a single table - if len(p.PrimaryKeyColumnNames) > 0 && p.TableName == "" { - return Plan{}, errors.New("cannot declare a primary key without also providing the TableNameIs option") - } - - // validation: oblast.PrimaryKeyInfo must refer to columns that exist - for _, columnName := range p.PrimaryKeyColumnNames { - _, ok := p.IndexByColumnName[columnName] - if !ok { - return Plan{}, fmt.Errorf("no field has tag `db:%q`, but a field of this name was declared in the primary key", columnName) - } - } - - // validation: LastInsertID() only works if at most one column is auto-filled, and if that column holds an integer type - if dialect.UsesLastInsertID() { - switch len(p.AutoColumnNames) { - case 0: - // nothing to check - case 1: - columnName := p.AutoColumnNames[0] - field := t.FieldByIndex(p.IndexByColumnName[columnName]) - switch field.Type.Kind() { //nolint:exhaustive // false positive - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - p.FillIDWithSetInt = true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - p.FillIDWithSetUint = true - default: - return Plan{}, fmt.Errorf( - "column is marked as auto-filled (%s), but this SQL dialect only supports auto-filling struct fields with integer types", - strings.Join(p.AutoColumnNames, ", "), - ) - } - default: - return Plan{}, fmt.Errorf( - "multiple columns are marked as auto-filled (%s), but this SQL dialect only supports at most one per table", - strings.Join(p.AutoColumnNames, ", "), - ) - } - } - - // prepare query strings - p.Select = p.buildSelectQueryIfPossible(dialect) - p.Insert = p.buildInsertQueryIfPossible(dialect) - p.Update = p.buildUpdateQueryIfPossible(dialect) - p.Delete = p.buildDeleteQueryIfPossible(dialect) - - return p, nil -} - -func (p Plan) getNonAutoColumnNames() []string { - result := make([]string, 0, len(p.AllColumnNames)-len(p.AutoColumnNames)) - for _, columnName := range p.AllColumnNames { - if !slices.Contains(p.AutoColumnNames, columnName) { - result = append(result, columnName) - } - } - return result -} - -func (p Plan) getNonPrimaryKeyColumnNames() []string { - result := make([]string, 0, len(p.AllColumnNames)-len(p.PrimaryKeyColumnNames)) - for _, columnName := range p.AllColumnNames { - if !slices.Contains(p.PrimaryKeyColumnNames, columnName) { - result = append(result, columnName) - } - } - return result -} - -func (p Plan) buildSelectQueryIfPossible(dialect Dialect) PlannedQuery { - if p.TableName == "" { - return PlannedQuery{Query: ""} - } - - var ( - scanIndexes = make([][]int, len(p.AllColumnNames)) - quotedColumnNames = make([]string, len(p.AllColumnNames)) - ) - for idx, columnName := range p.AllColumnNames { - scanIndexes[idx] = p.IndexByColumnName[columnName] - quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName) - } - - query := fmt.Sprintf( - `SELECT %s FROM %s WHERE `, - strings.Join(quotedColumnNames, ", "), - dialect.QuoteIdentifier(p.TableName), - ) - return PlannedQuery{query, nil, scanIndexes} -} - -func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery { - if p.TableName == "" || len(p.AllColumnNames) == 0 { - return PlannedQuery{Query: ""} - } - nonAutoColumnNames := p.getNonAutoColumnNames() - if len(nonAutoColumnNames) == 0 { - return PlannedQuery{Query: ""} - } - - var ( - argumentIndexes = make([][]int, len(nonAutoColumnNames)) - scanIndexes [][]int - quotedColumnNames = make([]string, len(nonAutoColumnNames)) - quotedPlaceholders = make([]string, len(nonAutoColumnNames)) - ) - for idx, columnName := range nonAutoColumnNames { - argumentIndexes[idx] = p.IndexByColumnName[columnName] - quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName) - quotedPlaceholders[idx] = dialect.Placeholder(idx) - } - if len(p.AutoColumnNames) > 0 { - // NOTE: This is filled even if dialect.UsesLastInsertID() is false. - // We need this index to find the right value on which to run SetInt() or SetUint(). - scanIndexes = make([][]int, len(p.AutoColumnNames)) - for idx, columnName := range p.AutoColumnNames { - scanIndexes[idx] = p.IndexByColumnName[columnName] - } - } - - query := fmt.Sprintf( - `INSERT INTO %s (%s) VALUES (%s)`, - dialect.QuoteIdentifier(p.TableName), - strings.Join(quotedColumnNames, ", "), - strings.Join(quotedPlaceholders, ", "), - ) - if len(p.AutoColumnNames) > 0 { - query += dialect.InsertSuffixForAutoColumns(p.AutoColumnNames) - } - return PlannedQuery{query, argumentIndexes, scanIndexes} -} - -func (p Plan) buildUpdateQueryIfPossible(dialect Dialect) PlannedQuery { - if p.TableName == "" || len(p.PrimaryKeyColumnNames) == 0 { - return PlannedQuery{Query: ""} - } - nonPrimaryKeyColumnNames := p.getNonPrimaryKeyColumnNames() - if len(nonPrimaryKeyColumnNames) == 0 { - return PlannedQuery{Query: ""} - } - - var ( - setArgumentIndexes = make([][]int, len(nonPrimaryKeyColumnNames)) - setClauses = make([]string, len(nonPrimaryKeyColumnNames)) - ) - for idx, columnName := range nonPrimaryKeyColumnNames { - setArgumentIndexes[idx] = p.IndexByColumnName[columnName] - setClauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx)) - } - - var ( - whereArgumentIndexes = make([][]int, len(p.PrimaryKeyColumnNames)) - whereClauses = make([]string, len(p.PrimaryKeyColumnNames)) - ) - for idx, columnName := range p.PrimaryKeyColumnNames { - whereArgumentIndexes[idx] = p.IndexByColumnName[columnName] - whereClauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx+len(setClauses))) - } - - query := fmt.Sprintf( - `UPDATE %s SET %s WHERE %s`, - dialect.QuoteIdentifier(p.TableName), - strings.Join(setClauses, ", "), - strings.Join(whereClauses, " AND "), - ) - return PlannedQuery{query, slices.Concat(setArgumentIndexes, whereArgumentIndexes), nil} -} - -func (p Plan) buildDeleteQueryIfPossible(dialect Dialect) PlannedQuery { - if p.TableName == "" || len(p.PrimaryKeyColumnNames) == 0 { - return PlannedQuery{Query: ""} - } - - var ( - argumentIndexes = make([][]int, len(p.PrimaryKeyColumnNames)) - clauses = make([]string, len(p.PrimaryKeyColumnNames)) - ) - for idx, columnName := range p.PrimaryKeyColumnNames { - argumentIndexes[idx] = p.IndexByColumnName[columnName] - clauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx)) - } - - query := fmt.Sprintf( - `DELETE FROM %s WHERE %s`, - dialect.QuoteIdentifier(p.TableName), - strings.Join(clauses, " AND "), - ) - return PlannedQuery{query, argumentIndexes, nil} -} diff --git a/internal/plan_test.go b/internal/plan_test.go deleted file mode 100644 index e692556..0000000 --- a/internal/plan_test.go +++ /dev/null @@ -1,277 +0,0 @@ -// SPDX-FileCopyrightText: 2026 Stefan Majewsky -// SPDX-License-Identifier: Apache-2.0 - -package internal_test - -import ( - "reflect" - "testing" - "time" - - "go.xyrillian.de/oblast/internal" - "go.xyrillian.de/oblast/internal/assert" -) - -func TestPlanFieldTraversal(t *testing.T) { - type Timestamps struct { - CreatedAt time.Time `db:"created_at"` - UpdatedAt *time.Time `db:"updated_at"` - } - type yetMoreTimestamps struct { - DeletedAt *time.Time `db:"deleted_at"` - } - type Log struct { - ID int64 `db:"id,auto"` - Message string - private1 bool `db:"private1"` //nolint:unused - Ignored any `db:"-"` - Timestamps - yetMoreTimestamps - } - - // check that the plan for Log: - // 1. has no IndexByColumnName entries for marker types - // 2. uses the field name as a column name for "Message" - // 3. ignores "private1" because it cannot be written through reflection - // 4. ignores "Ignored" because its column name is "-" - // 5. traverses into "Timestamps" and includes its fields as well - // 6. traverses into "yetMoreTimestamps" as well (despite the extra pointer and the type being private) - // 7. recognizes "id" as an autofilled column - plan, err := internal.BuildPlan(reflect.TypeFor[Log](), internal.PostgresDialect{}, internal.PlanOpts{ - TableName: "log_entries", - PrimaryKeyColumnNames: []string{"id"}, - }) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.TableName, "log_entries") - assert.DeepEqual(t, plan.AllColumnNames, []string{"id", "Message", "created_at", "updated_at", "deleted_at"}) - assert.DeepEqual(t, plan.PrimaryKeyColumnNames, []string{"id"}) - assert.DeepEqual(t, plan.AutoColumnNames, []string{"id"}) - assert.DeepEqual(t, plan.IndexByColumnName, map[string][]int{ - "id": {0}, - "Message": {1}, - "created_at": {4, 0}, - "updated_at": {4, 1}, - "deleted_at": {5, 0}, - }) -} - -// TODO: test that, during Select(), assignment into embedded fields with pointer-to-struct type works (docs say that this might panic if we do not allocate into the pointer first) - -func TestQueryConstructionBasic(t *testing.T) { - type record struct { - ID int64 `db:",auto"` - Description string - CreatedAt time.Time - } - opts := internal.PlanOpts{ - TableName: "basic_records", - PrimaryKeyColumnNames: []string{"ID"}, - } - - t.Run("PostgresDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[record](), internal.PostgresDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES ($1, $2) RETURNING "ID"`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) - assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = $1, "CreatedAt" = $2 WHERE "ID" = $3`) - assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, `DELETE FROM "basic_records" WHERE "ID" = $1`) - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) - - t.Run("SqliteDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[record](), internal.SqliteDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?)`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) - assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = ?, "CreatedAt" = ? WHERE "ID" = ?`) - assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, `DELETE FROM "basic_records" WHERE "ID" = ?`) - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) -} - -func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { - type relation struct { - FooID int64 `db:"foo_id"` - BarID int64 `db:"bar_id"` - } - opts := internal.PlanOpts{ - TableName: "foo_bar_relations", - } - - t.Run("PostgresDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[relation](), internal.PostgresDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2)`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) - assert.Equal(t, plan.Update.Query, "") - assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, "") - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) - - t.Run("SqliteDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[relation](), internal.SqliteDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) - assert.Equal(t, plan.Update.Query, "") - assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, "") - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) -} - -func TestQueryConstructionImpossble(t *testing.T) { - type unstructuredData struct { - Foo int - Bar string - } - opts := internal.PlanOpts{} - - testWith := func(dialect internal.Dialect) func(*testing.T) { - return func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[unstructuredData](), dialect, opts) - if err != nil { - t.Error(err) - } - - assert.Equal(t, plan.Select.Query, "") - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, nil) - assert.Equal(t, plan.Insert.Query, "") - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) - assert.Equal(t, plan.Update.Query, "") - assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, "") - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - } - } - - t.Run("PostgresDialect", testWith(internal.PostgresDialect{})) - t.Run("SqliteDialect", testWith(internal.SqliteDialect{})) -} - -func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { - type record struct { - GroupID int64 `db:"group_id"` - Name string `db:"name"` - CreatedAt time.Time `db:"created_at"` - } - opts := internal.PlanOpts{ - TableName: "complex_records", - PrimaryKeyColumnNames: []string{"group_id", "name"}, - } - - t.Run("PostgresDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[record](), internal.PostgresDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES ($1, $2, $3)`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) - assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = $1 WHERE "group_id" = $2 AND "name" = $3`) - assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}}) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, `DELETE FROM "complex_records" WHERE "group_id" = $1 AND "name" = $2`) - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}, {1}}) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) - - t.Run("SqliteDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[record](), internal.SqliteDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES (?, ?, ?)`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) - assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = ? WHERE "group_id" = ? AND "name" = ?`) - assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}}) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, `DELETE FROM "complex_records" WHERE "group_id" = ? AND "name" = ?`) - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}, {1}}) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) -} - -func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) { - type record struct { - ID int64 `db:"id,auto"` - Name string `db:"name"` - CreatedAt time.Time `db:"created_at,auto"` - } - opts := internal.PlanOpts{ - TableName: "autogenerated_records", - PrimaryKeyColumnNames: []string{"id"}, - } - - t.Run("PostgresDialect", func(t *testing.T) { - plan, err := internal.BuildPlan(reflect.TypeFor[record](), internal.PostgresDialect{}, opts) - if err != nil { - t.Error(err) - } - assert.Equal(t, plan.Select.Query, `SELECT "id", "name", "created_at" FROM "autogenerated_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) - assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "autogenerated_records" ("name") VALUES ($1) RETURNING "id", "created_at"`) - assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}}) - assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}, {2}}) - assert.Equal(t, plan.Update.Query, `UPDATE "autogenerated_records" SET "name" = $1, "created_at" = $2 WHERE "id" = $3`) - assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) - assert.DeepEqual(t, plan.Update.ScanIndexes, nil) - assert.Equal(t, plan.Delete.Query, `DELETE FROM "autogenerated_records" WHERE "id" = $1`) - assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) - assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) - }) - - t.Run("SqliteDialect", func(t *testing.T) { - _, err := internal.BuildPlan(reflect.TypeFor[record](), internal.SqliteDialect{}, opts) - assert.Equal(t, err.Error(), `cannot use type go.xyrillian.de/oblast/internal_test.record for queries: multiple columns are marked as auto-filled (id, created_at), but this SQL dialect only supports at most one per table`) - }) -} diff --git a/oblast.go b/oblast.go index 15f840a..7b40146 100644 --- a/oblast.go +++ b/oblast.go @@ -42,24 +42,23 @@ package oblast // import "go.xyrillian.de/oblast" import ( "database/sql" + "fmt" "reflect" - - "go.xyrillian.de/oblast/internal" ) // PlanOption is an option that can be given to NewStore() to influence query planning for a certain type of record. -type PlanOption func(*internal.PlanOpts) +type PlanOption func(*planOpts) // TableNameIs is a PlanOption for record types that correspond to exactly one database table (as opposed to a join of multiple tables). // This option is required to enable any of the methods of [Store] that use partially or fully auto-generated query strings. func TableNameIs(name string) PlanOption { - return func(opts *internal.PlanOpts) { opts.TableName = name } + return func(opts *planOpts) { opts.TableName = name } } // PrimaryKeyIs is a PlanOption for record types that correspond to a database table with a primary key. // This option is required to enable use of the [Store.Update] and [Store.Delete] methods. func PrimaryKeyIs(columnNames ...string) PlanOption { - return func(opts *internal.PlanOpts) { opts.PrimaryKeyColumnNames = columnNames } + return func(opts *planOpts) { opts.PrimaryKeyColumnNames = columnNames } } // Handle is an interface for functions providing direct DB access. @@ -83,7 +82,7 @@ var ( // and can also be used to execute autogenerated queries if the respective [PlanOption] values were provided during [NewStore]. type Store[R any] struct { dialect Dialect - plan internal.Plan + plan plan } // NewStore initializes a store for record type R. @@ -110,11 +109,15 @@ type Store[R any] struct { // Besides the declaration of a column name, the following extra tags are understood (as a comma-separated list following the column name): // - "auto": During [Store.Insert], do not store this field's value. Instead, the database will auto-generate a value, which will be read back into the record. func NewStore[R any](dialect Dialect, opts ...PlanOption) (Store[R], error) { - var popts internal.PlanOpts + var popts planOpts for _, opt := range opts { opt(&popts) } - plan, err := internal.BuildPlan(reflect.TypeFor[R](), dialect, popts) + plan, err := buildPlan(reflect.TypeFor[R](), dialect, popts) + if err != nil { + var zero R + return Store[R]{}, fmt.Errorf("cannot use type %T for queries: %w", zero, err) + } return Store[R]{dialect, plan}, err } diff --git a/plan.go b/plan.go new file mode 100644 index 0000000..da9f9b5 --- /dev/null +++ b/plan.go @@ -0,0 +1,294 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky +// SPDX-License-Identifier: Apache-2.0 + +package oblast + +import ( + "errors" + "fmt" + "reflect" + "slices" + "strings" +) + +// plan holds all information that we can derive from reflecting on a given type. +// The queries held within are only valid within the context of a given SQL dialect. +type plan struct { + TypeName string // for use in error messages + TableName string // from info.TableNameIs marker (if any) + AllColumnNames []string // in order of struct fields + PrimaryKeyColumnNames []string // from info.PrimaryKeyIs marker (if any) + AutoColumnNames []string // subset of AllColumnNames where field has `,auto` marker + + // Argument for reflect.Value.FieldByIndex() for each column name. + IndexByColumnName map[string][]int + + // In dialects with UsesLastInsertID() == true, whether the ID column must be written with reflect.Value.SetInt() or reflect.Value.SetUint(). + FillIDWithSetUint bool + FillIDWithSetInt bool + + // Planned queries. + Select plannedQuery // only `SELECT ... FROM ... WHERE `; user supplies the rest during Select{,One}Where() + Insert plannedQuery + Update plannedQuery + Delete plannedQuery +} + +// plannedQuery appears in type plan. +type plannedQuery struct { + // Empty if the respective query type is not supported by this plan for lack of the required marker types. + Query string + // Arguments for reflect.Value.FieldByIndex() in the correct order for the query arguments of the above query. + ArgumentIndexes [][]int + // Arguments for reflect.Value.FieldByIndex() in the correct order for the Scan() arguments of the above query. + ScanIndexes [][]int +} + +// planOpts holds additional arguments to buildPlan(). +type planOpts struct { + TableName string + PrimaryKeyColumnNames []string +} + +// buildPlan creates a new plan for the given struct type. +func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) { + if t.Kind() != reflect.Struct { + return plan{}, fmt.Errorf("expected struct type, but got kind %s", t.Kind().String()) + } + + var p = plan{ + TypeName: t.Name(), + TableName: opts.TableName, + PrimaryKeyColumnNames: opts.PrimaryKeyColumnNames, + IndexByColumnName: make(map[string][]int), + } + + // discover addressable fields in this type, + // collect information from markers and tags + for _, field := range reflect.VisibleFields(t) { + tags := strings.Split(strings.TrimSpace(field.Tag.Get("db")), ",") + + switch { + case field.PkgPath != "": + // ignore unexported fields (otherwise reflect.Value.Interface() on the field would panic) + continue + case field.Anonymous && field.Type.Kind() == reflect.Struct: + // for embedded struct fields, only consider their members, not the type itself, as a potential column + continue + default: + columnName, extraTags := tags[0], tags[1:] + if columnName == "-" { + continue + } + if columnName == "" { + columnName = field.Name + } + if otherIndex := p.IndexByColumnName[columnName]; otherIndex != nil { + return plan{}, fmt.Errorf( + "duplicate tag `db:%q` on field index %v, but also on field index %v", + columnName, otherIndex, field.Index, + ) + } + p.IndexByColumnName[columnName] = field.Index + p.AllColumnNames = append(p.AllColumnNames, columnName) + + for _, tag := range extraTags { + switch tag { + case "auto": + p.AutoColumnNames = append(p.AutoColumnNames, columnName) + default: + return plan{}, fmt.Errorf("unknown tag `db:%q` on field index %v", ","+tag, field.Index) + } + } + } + } + + // validation: defining a primary key only makes sense for records that map onto a single table + if len(p.PrimaryKeyColumnNames) > 0 && p.TableName == "" { + return plan{}, errors.New("cannot declare a primary key without also providing the TableNameIs option") + } + + // validation: oblast.PrimaryKeyInfo must refer to columns that exist + for _, columnName := range p.PrimaryKeyColumnNames { + _, ok := p.IndexByColumnName[columnName] + if !ok { + return plan{}, fmt.Errorf("no field has tag `db:%q`, but a field of this name was declared in the primary key", columnName) + } + } + + // validation: LastInsertID() only works if at most one column is auto-filled, and if that column holds an integer type + if dialect.UsesLastInsertID() { + switch len(p.AutoColumnNames) { + case 0: + // nothing to check + case 1: + columnName := p.AutoColumnNames[0] + field := t.FieldByIndex(p.IndexByColumnName[columnName]) + switch field.Type.Kind() { //nolint:exhaustive // false positive + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + p.FillIDWithSetInt = true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + p.FillIDWithSetUint = true + default: + return plan{}, fmt.Errorf( + "column is marked as auto-filled (%s), but this SQL dialect only supports auto-filling struct fields with integer types", + strings.Join(p.AutoColumnNames, ", "), + ) + } + default: + return plan{}, fmt.Errorf( + "multiple columns are marked as auto-filled (%s), but this SQL dialect only supports at most one per table", + strings.Join(p.AutoColumnNames, ", "), + ) + } + } + + // prepare query strings + p.Select = p.buildSelectQueryIfPossible(dialect) + p.Insert = p.buildInsertQueryIfPossible(dialect) + p.Update = p.buildUpdateQueryIfPossible(dialect) + p.Delete = p.buildDeleteQueryIfPossible(dialect) + + return p, nil +} + +func (p plan) getNonAutoColumnNames() []string { + result := make([]string, 0, len(p.AllColumnNames)-len(p.AutoColumnNames)) + for _, columnName := range p.AllColumnNames { + if !slices.Contains(p.AutoColumnNames, columnName) { + result = append(result, columnName) + } + } + return result +} + +func (p plan) getNonPrimaryKeyColumnNames() []string { + result := make([]string, 0, len(p.AllColumnNames)-len(p.PrimaryKeyColumnNames)) + for _, columnName := range p.AllColumnNames { + if !slices.Contains(p.PrimaryKeyColumnNames, columnName) { + result = append(result, columnName) + } + } + return result +} + +func (p plan) buildSelectQueryIfPossible(dialect Dialect) plannedQuery { + if p.TableName == "" { + return plannedQuery{Query: ""} + } + + var ( + scanIndexes = make([][]int, len(p.AllColumnNames)) + quotedColumnNames = make([]string, len(p.AllColumnNames)) + ) + for idx, columnName := range p.AllColumnNames { + scanIndexes[idx] = p.IndexByColumnName[columnName] + quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName) + } + + query := fmt.Sprintf( + `SELECT %s FROM %s WHERE `, + strings.Join(quotedColumnNames, ", "), + dialect.QuoteIdentifier(p.TableName), + ) + return plannedQuery{query, nil, scanIndexes} +} + +func (p plan) buildInsertQueryIfPossible(dialect Dialect) plannedQuery { + if p.TableName == "" || len(p.AllColumnNames) == 0 { + return plannedQuery{Query: ""} + } + nonAutoColumnNames := p.getNonAutoColumnNames() + if len(nonAutoColumnNames) == 0 { + return plannedQuery{Query: ""} + } + + var ( + argumentIndexes = make([][]int, len(nonAutoColumnNames)) + scanIndexes [][]int + quotedColumnNames = make([]string, len(nonAutoColumnNames)) + quotedPlaceholders = make([]string, len(nonAutoColumnNames)) + ) + for idx, columnName := range nonAutoColumnNames { + argumentIndexes[idx] = p.IndexByColumnName[columnName] + quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName) + quotedPlaceholders[idx] = dialect.Placeholder(idx) + } + if len(p.AutoColumnNames) > 0 { + // NOTE: This is filled even if dialect.UsesLastInsertID() is false. + // We need this index to find the right value on which to run SetInt() or SetUint(). + scanIndexes = make([][]int, len(p.AutoColumnNames)) + for idx, columnName := range p.AutoColumnNames { + scanIndexes[idx] = p.IndexByColumnName[columnName] + } + } + + query := fmt.Sprintf( + `INSERT INTO %s (%s) VALUES (%s)`, + dialect.QuoteIdentifier(p.TableName), + strings.Join(quotedColumnNames, ", "), + strings.Join(quotedPlaceholders, ", "), + ) + if len(p.AutoColumnNames) > 0 { + query += dialect.InsertSuffixForAutoColumns(p.AutoColumnNames) + } + return plannedQuery{query, argumentIndexes, scanIndexes} +} + +func (p plan) buildUpdateQueryIfPossible(dialect Dialect) plannedQuery { + if p.TableName == "" || len(p.PrimaryKeyColumnNames) == 0 { + return plannedQuery{Query: ""} + } + nonPrimaryKeyColumnNames := p.getNonPrimaryKeyColumnNames() + if len(nonPrimaryKeyColumnNames) == 0 { + return plannedQuery{Query: ""} + } + + var ( + setArgumentIndexes = make([][]int, len(nonPrimaryKeyColumnNames)) + setClauses = make([]string, len(nonPrimaryKeyColumnNames)) + ) + for idx, columnName := range nonPrimaryKeyColumnNames { + setArgumentIndexes[idx] = p.IndexByColumnName[columnName] + setClauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx)) + } + + var ( + whereArgumentIndexes = make([][]int, len(p.PrimaryKeyColumnNames)) + whereClauses = make([]string, len(p.PrimaryKeyColumnNames)) + ) + for idx, columnName := range p.PrimaryKeyColumnNames { + whereArgumentIndexes[idx] = p.IndexByColumnName[columnName] + whereClauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx+len(setClauses))) + } + + query := fmt.Sprintf( + `UPDATE %s SET %s WHERE %s`, + dialect.QuoteIdentifier(p.TableName), + strings.Join(setClauses, ", "), + strings.Join(whereClauses, " AND "), + ) + return plannedQuery{query, slices.Concat(setArgumentIndexes, whereArgumentIndexes), nil} +} + +func (p plan) buildDeleteQueryIfPossible(dialect Dialect) plannedQuery { + if p.TableName == "" || len(p.PrimaryKeyColumnNames) == 0 { + return plannedQuery{Query: ""} + } + + var ( + argumentIndexes = make([][]int, len(p.PrimaryKeyColumnNames)) + clauses = make([]string, len(p.PrimaryKeyColumnNames)) + ) + for idx, columnName := range p.PrimaryKeyColumnNames { + argumentIndexes[idx] = p.IndexByColumnName[columnName] + clauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx)) + } + + query := fmt.Sprintf( + `DELETE FROM %s WHERE %s`, + dialect.QuoteIdentifier(p.TableName), + strings.Join(clauses, " AND "), + ) + return plannedQuery{query, argumentIndexes, nil} +} diff --git a/plan_test.go b/plan_test.go new file mode 100644 index 0000000..1095016 --- /dev/null +++ b/plan_test.go @@ -0,0 +1,278 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky +// SPDX-License-Identifier: Apache-2.0 + +package oblast + +// ^ NOTE: This is testing internal types and thus must reside in the same package. + +import ( + "reflect" + "testing" + "time" + + "go.xyrillian.de/oblast/internal/assert" +) + +func TestPlanFieldTraversal(t *testing.T) { + type Timestamps struct { + CreatedAt time.Time `db:"created_at"` + UpdatedAt *time.Time `db:"updated_at"` + } + type yetMoreTimestamps struct { + DeletedAt *time.Time `db:"deleted_at"` + } + type Log struct { + ID int64 `db:"id,auto"` + Message string + private1 bool `db:"private1"` //nolint:unused + Ignored any `db:"-"` + Timestamps + yetMoreTimestamps + } + + // check that the plan for Log: + // 1. has no IndexByColumnName entries for marker types + // 2. uses the field name as a column name for "Message" + // 3. ignores "private1" because it cannot be written through reflection + // 4. ignores "Ignored" because its column name is "-" + // 5. traverses into "Timestamps" and includes its fields as well + // 6. traverses into "yetMoreTimestamps" as well (despite the extra pointer and the type being private) + // 7. recognizes "id" as an autofilled column + plan, err := buildPlan(reflect.TypeFor[Log](), PostgresDialect(), planOpts{ + TableName: "log_entries", + PrimaryKeyColumnNames: []string{"id"}, + }) + if err != nil { + t.Error(err) + } + assert.Equal(t, plan.TableName, "log_entries") + assert.DeepEqual(t, plan.AllColumnNames, []string{"id", "Message", "created_at", "updated_at", "deleted_at"}) + assert.DeepEqual(t, plan.PrimaryKeyColumnNames, []string{"id"}) + assert.DeepEqual(t, plan.AutoColumnNames, []string{"id"}) + assert.DeepEqual(t, plan.IndexByColumnName, map[string][]int{ + "id": {0}, + "Message": {1}, + "created_at": {4, 0}, + "updated_at": {4, 1}, + "deleted_at": {5, 0}, + }) +} + +// TODO: test that, during Select(), assignment into embedded fields with pointer-to-struct type works (docs say that this might panic if we do not allocate into the pointer first) + +func TestQueryConstructionBasic(t *testing.T) { + type record struct { + ID int64 `db:",auto"` + Description string + CreatedAt time.Time + } + opts := planOpts{ + TableName: "basic_records", + PrimaryKeyColumnNames: []string{"ID"}, + } + + t.Run("PostgresDialect", func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[record](), PostgresDialect(), opts) + if err != nil { + t.Error(err) + } + assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) + assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES ($1, $2) RETURNING "ID"`) + assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) + assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = $1, "CreatedAt" = $2 WHERE "ID" = $3`) + assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) + assert.Equal(t, plan.Delete.Query, `DELETE FROM "basic_records" WHERE "ID" = $1`) + assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) + }) + + t.Run("SqliteDialect", func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[record](), SqliteDialect(), opts) + if err != nil { + t.Error(err) + } + assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) + assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?)`) + assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) + assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = ?, "CreatedAt" = ? WHERE "ID" = ?`) + assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) + assert.Equal(t, plan.Delete.Query, `DELETE FROM "basic_records" WHERE "ID" = ?`) + assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) + }) +} + +func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { + type relation struct { + FooID int64 `db:"foo_id"` + BarID int64 `db:"bar_id"` + } + opts := planOpts{ + TableName: "foo_bar_relations", + } + + t.Run("PostgresDialect", func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[relation](), PostgresDialect(), opts) + if err != nil { + t.Error(err) + } + assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) + assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2)`) + assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Update.Query, "") + assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) + assert.Equal(t, plan.Delete.Query, "") + assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) + }) + + t.Run("SqliteDialect", func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[relation](), SqliteDialect(), opts) + if err != nil { + t.Error(err) + } + assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) + assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`) + assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Update.Query, "") + assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) + assert.Equal(t, plan.Delete.Query, "") + assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) + }) +} + +func TestQueryConstructionImpossble(t *testing.T) { + type unstructuredData struct { + Foo int + Bar string + } + opts := planOpts{} + + testWith := func(dialect Dialect) func(*testing.T) { + return func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[unstructuredData](), dialect, opts) + if err != nil { + t.Error(err) + } + + assert.Equal(t, plan.Select.Query, "") + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, nil) + assert.Equal(t, plan.Insert.Query, "") + assert.DeepEqual(t, plan.Insert.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Update.Query, "") + assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) + assert.Equal(t, plan.Delete.Query, "") + assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) + } + } + + t.Run("PostgresDialect", testWith(PostgresDialect())) + t.Run("SqliteDialect", testWith(SqliteDialect())) +} + +func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { + type record struct { + GroupID int64 `db:"group_id"` + Name string `db:"name"` + CreatedAt time.Time `db:"created_at"` + } + opts := planOpts{ + TableName: "complex_records", + PrimaryKeyColumnNames: []string{"group_id", "name"}, + } + + t.Run("PostgresDialect", func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[record](), PostgresDialect(), opts) + if err != nil { + t.Error(err) + } + assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) + assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES ($1, $2, $3)`) + assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = $1 WHERE "group_id" = $2 AND "name" = $3`) + assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}}) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) + assert.Equal(t, plan.Delete.Query, `DELETE FROM "complex_records" WHERE "group_id" = $1 AND "name" = $2`) + assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) + }) + + t.Run("SqliteDialect", func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[record](), SqliteDialect(), opts) + if err != nil { + t.Error(err) + } + assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) + assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES (?, ?, ?)`) + assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = ? WHERE "group_id" = ? AND "name" = ?`) + assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}}) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) + assert.Equal(t, plan.Delete.Query, `DELETE FROM "complex_records" WHERE "group_id" = ? AND "name" = ?`) + assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) + }) +} + +func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) { + type record struct { + ID int64 `db:"id,auto"` + Name string `db:"name"` + CreatedAt time.Time `db:"created_at,auto"` + } + opts := planOpts{ + TableName: "autogenerated_records", + PrimaryKeyColumnNames: []string{"id"}, + } + + t.Run("PostgresDialect", func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[record](), PostgresDialect(), opts) + if err != nil { + t.Error(err) + } + assert.Equal(t, plan.Select.Query, `SELECT "id", "name", "created_at" FROM "autogenerated_records" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) + assert.Equal(t, plan.Insert.Query, `INSERT INTO "autogenerated_records" ("name") VALUES ($1) RETURNING "id", "created_at"`) + assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}, {2}}) + assert.Equal(t, plan.Update.Query, `UPDATE "autogenerated_records" SET "name" = $1, "created_at" = $2 WHERE "id" = $3`) + assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) + assert.Equal(t, plan.Delete.Query, `DELETE FROM "autogenerated_records" WHERE "id" = $1`) + assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) + }) + + t.Run("SqliteDialect", func(t *testing.T) { + _, err := NewStore[record](SqliteDialect()) + assert.Equal(t, err.Error(), `cannot use type oblast.record for queries: multiple columns are marked as auto-filled (id, created_at), but this SQL dialect only supports at most one per table`) + }) +} diff --git a/select.go b/select.go index e6eccb1..8aed249 100644 --- a/select.go +++ b/select.go @@ -8,8 +8,6 @@ import ( "errors" "fmt" "reflect" - - "go.xyrillian.de/oblast/internal" ) // Select executes the provided SQL query and fills an instance of the record type R for each row in the result set, @@ -79,7 +77,7 @@ func (s Store[R]) SelectWhere(db Handle, partialQuery string, args ...any) (resu return result, nil } -func startSelectQuery(db Handle, plan internal.Plan, query string, args ...any) (returnedRows *sql.Rows, indexes [][]int, returnedError error) { +func startSelectQuery(db Handle, plan plan, query string, args ...any) (returnedRows *sql.Rows, indexes [][]int, returnedError error) { rows, err := db.Query(query, args...) if err != nil { return nil, nil, fmt.Errorf("during Query(): %w", err) @@ -112,7 +110,7 @@ func startSelectQuery(db Handle, plan internal.Plan, query string, args ...any) return rows, indexes, nil } -func startSelectWhereQuery(db Handle, plan internal.Plan, partialQuery string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { +func startSelectWhereQuery(db Handle, plan plan, partialQuery string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { if plan.Select.Query == "" { return nil, nil, errors.New("cannot execute SelectWhere() because query could not be autogenerated") } @@ -175,7 +173,7 @@ func (s Store[R]) SelectOneWhere(db Handle, partialQuery string, args ...any) (r return } -func selectOneWhere(db Handle, plan internal.Plan, v reflect.Value, partialQuery string, args []any) error { +func selectOneWhere(db Handle, plan plan, v reflect.Value, partialQuery string, args []any) error { if plan.Select.Query == "" { return errors.New("cannot execute SelectOneWhere() because query could not be autogenerated") } -- cgit v1.2.3