diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/assert/assert.go | 25 | ||||
| -rw-r--r-- | internal/dialect.go | 41 | ||||
| -rw-r--r-- | internal/plan.go | 138 | ||||
| -rw-r--r-- | internal/plan_test.go | 72 |
4 files changed, 276 insertions, 0 deletions
diff --git a/internal/assert/assert.go b/internal/assert/assert.go new file mode 100644 index 0000000..c4e7b50 --- /dev/null +++ b/internal/assert/assert.go @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net> +// SPDX-License-Identifier: Apache-2.0 + +package assert + +import ( + "reflect" + "testing" +) + +// Equal is a test assertion. +func Equal[V comparable](t *testing.T, actual, expected V) { + t.Helper() + if actual != expected { + t.Errorf("expected %#v, but got %#v", expected, actual) + } +} + +// DeepEqual is a test assertion. +func DeepEqual[V any](t *testing.T, actual, expected V) { + t.Helper() + if !reflect.DeepEqual(actual, expected) { + t.Errorf("expected %#v, but got %#v", expected, actual) + } +} diff --git a/internal/dialect.go b/internal/dialect.go new file mode 100644 index 0000000..0cf90a2 --- /dev/null +++ b/internal/dialect.go @@ -0,0 +1,41 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net> +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "strconv" + "strings" +) + +// Dialect is a copy of the interface of the same name in package oblast. +// We cannot refer to that interface within this package because that would constitute a cyclic dependency. +type Dialect interface { + Placeholder(i int) string + QuoteIdentifier(name string) string + UsesLastInsertID() bool + InsertSuffixForAutoColumns(columns []string) string +} + +// PostgresDialect is the dialect of PostgreSQL databases. +type PostgresDialect struct{} + +func (PostgresDialect) Placeholder(i int) string { return "$" + strconv.Itoa(i) } +func (PostgresDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } +func (PostgresDialect) UsesLastInsertID() bool { return false } + +func (p PostgresDialect) InsertSuffixForAutoColumns(columns []string) string { + quotedColumns := make([]string, len(columns)) + for idx, name := range columns { + quotedColumns[idx] = p.QuoteIdentifier(name) + } + return ` RETURNING ` + strings.Join(quotedColumns, ", ") +} + +// SqliteDialect is the dialect of SQLite databases. +type SqliteDialect struct{} + +func (SqliteDialect) Placeholder(_ int) string { return "?" } +func (SqliteDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } +func (SqliteDialect) UsesLastInsertID() bool { return true } +func (SqliteDialect) InsertSuffixForAutoColumns(columns []string) string { return "" } diff --git a/internal/plan.go b/internal/plan.go new file mode 100644 index 0000000..0defd15 --- /dev/null +++ b/internal/plan.go @@ -0,0 +1,138 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net> +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "fmt" + "reflect" + "slices" + "strings" + + "go.xyrillian.de/oblast/info" +) + +// Plan holds all information that we can derive from reflecting on a given type. +// The queries held within are only valid within the context of a given SQL dialect. +type Plan struct { + // Information extracted from applicable marker types (if any). + TableName string + PrimaryKeyColumns []string + + // Argument for reflect.Value.FieldByIndex() for each column name. + IndexByColumnName map[string][]int + // Which columns will be filled automatically by the DB during insert. + // This corresponds to having a tag like `db:"foo,auto"`. + // In DB dialects that use LastInsertID(), this list may have at most one element. + AutoColumns []string + + // Prepared queries (or empty strings if the respective query types are not + // supported for lack of the respective markers). + InsertQuery string + UpdateQuery string + DeleteQuery string + + // Arguments for reflect.Value.FieldByIndex() in the required order for p.InsertQuery. + InsertFieldOrder [][]int +} + +var ( + tableNameMarkerType = reflect.TypeFor[info.TableNameIs]() + primaryKeyMarkerType = reflect.TypeFor[info.PrimaryKeyIs]() +) + +func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) { + if t.Kind() != reflect.Struct { + return Plan{}, fmt.Errorf("expected record type to be a struct, but got kind %s (full type: %s.%s)", + t.Kind(), t.PkgPath(), t.Name()) + } + + var p = Plan{ + IndexByColumnName: make(map[string][]int), + } + + // discover addressable fields in this type, + // collect information from markers and tags + for _, index := range getAllAddressableFieldIndexes(t) { + field := t.FieldByIndex(index) + fullTag := strings.TrimSpace(field.Tag.Get("db")) + if fullTag == "" || fullTag == "-" { + continue + } + tags := strings.Split(fullTag, ",") + + switch field.Type { + case tableNameMarkerType: + // only consider this marker when directly on `t` itself, not within embedded fields + if len(index) == 1 { + if len(tags) > 1 { + return Plan{}, fmt.Errorf("invalid table name %q (may not contain commas)", fullTag) + } + p.TableName = tags[0] + } + case primaryKeyMarkerType: + // only consider this marker when directly on `t` itself, not within embedded fields + if len(index) == 1 { + p.PrimaryKeyColumns = tags + } + default: + columnName, extraTags := tags[0], tags[1:] + if otherIndex := p.IndexByColumnName[columnName]; otherIndex != nil { + return Plan{}, fmt.Errorf( + "duplicate tag `db:%q` on field index %v, but also on field index %v", + columnName, otherIndex, index, + ) + } + p.IndexByColumnName[columnName] = index + + for _, tag := range extraTags { + switch tag { + case "auto": + p.AutoColumns = append(p.AutoColumns, columnName) + default: + return Plan{}, fmt.Errorf("unknown tag `db:%q` on field index %v", ","+tag, index) + } + } + } + } + + // validation: oblast.PrimaryKeyInfo must refer to columns that exist + for _, columnName := range p.PrimaryKeyColumns { + _, ok := p.IndexByColumnName[columnName] + if !ok { + return Plan{}, fmt.Errorf("PrimaryKeyInfo refers to column %[1]q, but no field has tag `db:%[1]q`", columnName) + } + } + + // validation: LastInsertID() only works if at most one column is auto-filled + if dialect.UsesLastInsertID() && len(p.AutoColumns) > 1 { + return Plan{}, fmt.Errorf( + "multiple columns are marked as auto-filled (%s), but this SQL dialect only supports at most one per table", + strings.Join(p.AutoColumns, ", "), + ) + } + + // TODO: build INSERT query if possible + // TODO: build UPDATE query if possible + // TODO: build DELETE query if possible + + return p, nil +} + +// WARNING: Panics if t.Kind() != reflect.Struct. +func getAllAddressableFieldIndexes(t reflect.Type) (result [][]int) { + for field := range t.Fields() { + // recurse into embedded fields + if field.Anonymous && field.Type.Kind() == reflect.Struct { + for _, subindex := range getAllAddressableFieldIndexes(field.Type) { + result = append(result, append(slices.Clone(field.Index), subindex...)) + } + } + + // only fields are addressable (otherwise reflect.Value.Interface() on the field would panic) + if field.PkgPath == "" { + result = append(result, field.Index) + } + } + return result +} diff --git a/internal/plan_test.go b/internal/plan_test.go new file mode 100644 index 0000000..827c6e4 --- /dev/null +++ b/internal/plan_test.go @@ -0,0 +1,72 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net> +// SPDX-License-Identifier: Apache-2.0 + +package internal_test + +import ( + "reflect" + "testing" + "time" + + "go.xyrillian.de/oblast/info" + "go.xyrillian.de/oblast/internal" + "go.xyrillian.de/oblast/internal/assert" +) + +func TestPlanFieldTraversal(t *testing.T) { + type Log struct { + info.TableNameIs `db:"log_entries"` + info.PrimaryKeyIs `db:"id"` + ID int64 `db:"id,auto"` + CreatedAt time.Time `db:"created_at"` + Message string `db:"message"` + private1 bool `db:"private1"` //nolint:unused + } + + // assert on interface implementations + var ( + _ info.IsTable = Log{} + _ info.IsTableWithPrimaryKey = Log{} + ) + + // check that the plan for Log: + // 1. has no IndexByColumnName entries for marker types + // 2. ignores "private1" because it cannot be written through reflection + // 3. recognizes "id" as an autofilled column + plan, err := internal.BuildPlan(reflect.TypeFor[Log](), internal.PostgresDialect{}) + if err != nil { + t.Error(err) + } + assert.Equal(t, plan.TableName, "log_entries") + assert.DeepEqual(t, plan.PrimaryKeyColumns, []string{"id"}) + assert.DeepEqual(t, plan.AutoColumns, []string{"id"}) + assert.DeepEqual(t, plan.IndexByColumnName, map[string][]int{ + "id": {2}, + "created_at": {3}, + "message": {4}, + }) + + type record struct { + Log + Keks bool `db:"keks"` + private2 bool `db:"private2"` //nolint:unused + } + + // check that the plan for record: + // 1. works at all, even though it as a whole is an unexported type + // 2. traverses into Log and includes all of its fields as well + // 3. completely ignores the marker types in type Log + plan, err = internal.BuildPlan(reflect.TypeFor[record](), internal.PostgresDialect{}) + if err != nil { + t.Error(err) + } + assert.Equal(t, plan.TableName, "") + assert.DeepEqual(t, plan.PrimaryKeyColumns, nil) + assert.DeepEqual(t, plan.AutoColumns, []string{"id"}) // this is okay, it does not bear significance in practice since no queries are generated + assert.DeepEqual(t, plan.IndexByColumnName, map[string][]int{ + "id": {0, 2}, + "created_at": {0, 3}, + "message": {0, 4}, + "keks": {1}, + }) +} |
