diff options
| -rw-r--r-- | dialect.go | 1 | ||||
| -rw-r--r-- | internal/dialect.go | 2 | ||||
| -rw-r--r-- | internal/plan.go | 185 | ||||
| -rw-r--r-- | internal/plan_test.go | 46 |
4 files changed, 188 insertions, 46 deletions
@@ -16,6 +16,7 @@ import "go.xyrillian.de/oblast/internal" type Dialect interface { // Placeholder returns the placeholder for the i-th query argument. // Most dialects use "?", but e.g. PostgreSQL uses "$1", "$2" and so on. + // The argument numbers from 0 like a slice index. Placeholder(i int) string // QuoteIdentifier wraps the name of a column or table in quotes, diff --git a/internal/dialect.go b/internal/dialect.go index 0cf90a2..e6db5b8 100644 --- a/internal/dialect.go +++ b/internal/dialect.go @@ -20,7 +20,7 @@ type Dialect interface { // PostgresDialect is the dialect of PostgreSQL databases. type PostgresDialect struct{} -func (PostgresDialect) Placeholder(i int) string { return "$" + strconv.Itoa(i) } +func (PostgresDialect) Placeholder(i int) string { return "$" + strconv.Itoa(i+1) } func (PostgresDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } func (PostgresDialect) UsesLastInsertID() bool { return false } diff --git a/internal/plan.go b/internal/plan.go index 0defd15..8fc24d8 100644 --- a/internal/plan.go +++ b/internal/plan.go @@ -15,25 +15,28 @@ import ( // Plan holds all information that we can derive from reflecting on a given type. // The queries held within are only valid within the context of a given SQL dialect. type Plan struct { - // Information extracted from applicable marker types (if any). - TableName string - PrimaryKeyColumns []string + TableName string // from info.TableNameIs marker (if any) + AllColumnNames []string // in order of struct fields + PrimaryKeyColumnNames []string // from info.PrimaryKeyIs marker (if any) + AutoColumnNames []string // subset of AllColumnNames where field has `,auto` marker // Argument for reflect.Value.FieldByIndex() for each column name. IndexByColumnName map[string][]int - // Which columns will be filled automatically by the DB during insert. - // This corresponds to having a tag like `db:"foo,auto"`. - // In DB dialects that use LastInsertID(), this list may have at most one element. - AutoColumns []string - - // Prepared queries (or empty strings if the respective query types are not - // supported for lack of the respective markers). - InsertQuery string - UpdateQuery string - DeleteQuery string - - // Arguments for reflect.Value.FieldByIndex() in the required order for p.InsertQuery. - InsertFieldOrder [][]int + + // Planned queries. + Insert PlannedQuery + Update PlannedQuery + Delete PlannedQuery +} + +// PlannedQuery appears in type Plan. +type PlannedQuery struct { + // Empty if the respective query type is not supported by this Plan + // for lack of the required marker types. + Query string + // Arguments for reflect.Value.FieldByIndex() in the correct order + // for the query arguments of the above query. + ArgumentIndexes [][]int } var ( @@ -41,6 +44,7 @@ var ( primaryKeyMarkerType = reflect.TypeFor[info.PrimaryKeyIs]() ) +// BuildPlan creates a new plan for the given struct type. func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) { if t.Kind() != reflect.Struct { return Plan{}, fmt.Errorf("expected record type to be a struct, but got kind %s (full type: %s.%s)", @@ -55,28 +59,33 @@ func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) { // collect information from markers and tags for _, index := range getAllAddressableFieldIndexes(t) { field := t.FieldByIndex(index) - fullTag := strings.TrimSpace(field.Tag.Get("db")) - if fullTag == "" || fullTag == "-" { - continue - } - tags := strings.Split(fullTag, ",") + tags := strings.Split(strings.TrimSpace(field.Tag.Get("db")), ",") - switch field.Type { - case tableNameMarkerType: + switch { + case field.Type == tableNameMarkerType: // only consider this marker when directly on `t` itself, not within embedded fields if len(index) == 1 { if len(tags) > 1 { - return Plan{}, fmt.Errorf("invalid table name %q (may not contain commas)", fullTag) + return Plan{}, fmt.Errorf("invalid table name %q (may not contain commas)", field.Tag.Get("db")) } p.TableName = tags[0] } - case primaryKeyMarkerType: + case field.Type == primaryKeyMarkerType: // only consider this marker when directly on `t` itself, not within embedded fields if len(index) == 1 { - p.PrimaryKeyColumns = tags + p.PrimaryKeyColumnNames = tags } + case field.Anonymous && field.Type.Kind() == reflect.Struct: + // for embedded struct fields, only consider their members, not the type itself, as a potential column + continue default: columnName, extraTags := tags[0], tags[1:] + if columnName == "-" { + continue + } + if columnName == "" { + columnName = field.Name + } if otherIndex := p.IndexByColumnName[columnName]; otherIndex != nil { return Plan{}, fmt.Errorf( "duplicate tag `db:%q` on field index %v, but also on field index %v", @@ -84,11 +93,12 @@ func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) { ) } p.IndexByColumnName[columnName] = index + p.AllColumnNames = append(p.AllColumnNames, columnName) for _, tag := range extraTags { switch tag { case "auto": - p.AutoColumns = append(p.AutoColumns, columnName) + p.AutoColumnNames = append(p.AutoColumnNames, columnName) default: return Plan{}, fmt.Errorf("unknown tag `db:%q` on field index %v", ","+tag, index) } @@ -97,7 +107,7 @@ func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) { } // validation: oblast.PrimaryKeyInfo must refer to columns that exist - for _, columnName := range p.PrimaryKeyColumns { + for _, columnName := range p.PrimaryKeyColumnNames { _, ok := p.IndexByColumnName[columnName] if !ok { return Plan{}, fmt.Errorf("PrimaryKeyInfo refers to column %[1]q, but no field has tag `db:%[1]q`", columnName) @@ -105,16 +115,17 @@ func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) { } // validation: LastInsertID() only works if at most one column is auto-filled - if dialect.UsesLastInsertID() && len(p.AutoColumns) > 1 { + if dialect.UsesLastInsertID() && len(p.AutoColumnNames) > 1 { return Plan{}, fmt.Errorf( "multiple columns are marked as auto-filled (%s), but this SQL dialect only supports at most one per table", - strings.Join(p.AutoColumns, ", "), + strings.Join(p.AutoColumnNames, ", "), ) } - // TODO: build INSERT query if possible - // TODO: build UPDATE query if possible - // TODO: build DELETE query if possible + // prepare query strings + p.Insert = p.buildInsertQueryIfPossible(dialect) + p.Update = p.buildUpdateQueryIfPossible(dialect) + p.Delete = p.buildDeleteQueryIfPossible(dialect) return p, nil } @@ -136,3 +147,111 @@ func getAllAddressableFieldIndexes(t reflect.Type) (result [][]int) { } return result } + +func (p Plan) getNonAutoColumnNames() []string { + result := make([]string, 0, len(p.AllColumnNames)-len(p.AutoColumnNames)) + for _, columnName := range p.AllColumnNames { + if !slices.Contains(p.AutoColumnNames, columnName) { + result = append(result, columnName) + } + } + return result +} + +func (p Plan) getNonPrimaryKeyColumnNames() []string { + result := make([]string, 0, len(p.AllColumnNames)-len(p.PrimaryKeyColumnNames)) + for _, columnName := range p.AllColumnNames { + if !slices.Contains(p.PrimaryKeyColumnNames, columnName) { + result = append(result, columnName) + } + } + return result +} + +func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery { + if p.TableName == "" || len(p.AllColumnNames) == 0 { + return PlannedQuery{Query: ""} + } + nonAutoColumnNames := p.getNonAutoColumnNames() + if len(nonAutoColumnNames) == 0 { + return PlannedQuery{Query: ""} + } + + var ( + argumentIndexes = make([][]int, len(nonAutoColumnNames)) + quotedColumnNames = make([]string, len(nonAutoColumnNames)) + quotedPlaceholders = make([]string, len(nonAutoColumnNames)) + ) + for idx, columnName := range nonAutoColumnNames { + argumentIndexes[idx] = p.IndexByColumnName[columnName] + quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName) + quotedPlaceholders[idx] = dialect.Placeholder(idx) + } + + query := fmt.Sprintf( + `INSERT INTO %s (%s) VALUES (%s)%s`, + dialect.QuoteIdentifier(p.TableName), + strings.Join(quotedColumnNames, ", "), + strings.Join(quotedPlaceholders, ", "), + dialect.InsertSuffixForAutoColumns(p.AutoColumnNames), + ) + return PlannedQuery{query, argumentIndexes} +} + +func (p Plan) buildUpdateQueryIfPossible(dialect Dialect) PlannedQuery { + if p.TableName == "" || len(p.PrimaryKeyColumnNames) == 0 { + return PlannedQuery{Query: ""} + } + nonPrimaryKeyColumnNames := p.getNonPrimaryKeyColumnNames() + if len(nonPrimaryKeyColumnNames) == 0 { + return PlannedQuery{Query: ""} + } + + var ( + setArgumentIndexes = make([][]int, len(nonPrimaryKeyColumnNames)) + setClauses = make([]string, len(nonPrimaryKeyColumnNames)) + ) + for idx, columnName := range nonPrimaryKeyColumnNames { + setArgumentIndexes[idx] = p.IndexByColumnName[columnName] + setClauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx)) + } + + var ( + whereArgumentIndexes = make([][]int, len(p.PrimaryKeyColumnNames)) + whereClauses = make([]string, len(p.PrimaryKeyColumnNames)) + ) + for idx, columnName := range p.PrimaryKeyColumnNames { + whereArgumentIndexes[idx] = p.IndexByColumnName[columnName] + whereClauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx+len(setClauses))) + } + + query := fmt.Sprintf( + `UPDATE %s SET %s WHERE %s`, + dialect.QuoteIdentifier(p.TableName), + strings.Join(setClauses, ", "), + strings.Join(whereClauses, " AND "), + ) + return PlannedQuery{query, slices.Concat(setArgumentIndexes, whereArgumentIndexes)} +} + +func (p Plan) buildDeleteQueryIfPossible(dialect Dialect) PlannedQuery { + if p.TableName == "" || len(p.PrimaryKeyColumnNames) == 0 { + return PlannedQuery{Query: ""} + } + + var ( + argumentIndexes = make([][]int, len(p.PrimaryKeyColumnNames)) + clauses = make([]string, len(p.PrimaryKeyColumnNames)) + ) + for idx, columnName := range p.PrimaryKeyColumnNames { + argumentIndexes[idx] = p.IndexByColumnName[columnName] + clauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx)) + } + + query := fmt.Sprintf( + `DELETE FROM %s WHERE %s`, + dialect.QuoteIdentifier(p.TableName), + strings.Join(clauses, " AND "), + ) + return PlannedQuery{query, argumentIndexes} +} diff --git a/internal/plan_test.go b/internal/plan_test.go index 827c6e4..570833c 100644 --- a/internal/plan_test.go +++ b/internal/plan_test.go @@ -19,8 +19,9 @@ func TestPlanFieldTraversal(t *testing.T) { info.PrimaryKeyIs `db:"id"` ID int64 `db:"id,auto"` CreatedAt time.Time `db:"created_at"` - Message string `db:"message"` - private1 bool `db:"private1"` //nolint:unused + Message string + private1 bool `db:"private1"` //nolint:unused + Ignored any `db:"-"` } // assert on interface implementations @@ -31,24 +32,40 @@ func TestPlanFieldTraversal(t *testing.T) { // check that the plan for Log: // 1. has no IndexByColumnName entries for marker types - // 2. ignores "private1" because it cannot be written through reflection - // 3. recognizes "id" as an autofilled column + // 2. uses the field name as a column name for "Message" + // 3. ignores "private1" because it cannot be written through reflection + // 4. ignores "Ignored" because its column name is "-" + // 5. recognizes "id" as an autofilled column plan, err := internal.BuildPlan(reflect.TypeFor[Log](), internal.PostgresDialect{}) if err != nil { t.Error(err) } assert.Equal(t, plan.TableName, "log_entries") - assert.DeepEqual(t, plan.PrimaryKeyColumns, []string{"id"}) - assert.DeepEqual(t, plan.AutoColumns, []string{"id"}) + assert.DeepEqual(t, plan.AllColumnNames, []string{"id", "created_at", "Message"}) + assert.DeepEqual(t, plan.PrimaryKeyColumnNames, []string{"id"}) + assert.DeepEqual(t, plan.AutoColumnNames, []string{"id"}) assert.DeepEqual(t, plan.IndexByColumnName, map[string][]int{ "id": {2}, "created_at": {3}, - "message": {4}, + "Message": {4}, }) + assert.Equal(t, plan.Insert.Query, + `INSERT INTO "log_entries" ("created_at", "Message") VALUES ($1, $2) RETURNING "id"`, + ) + assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{3}, {4}}) + assert.Equal(t, plan.Update.Query, + `UPDATE "log_entries" SET "created_at" = $1, "Message" = $2 WHERE "id" = $3`, + ) + assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{3}, {4}, {2}}) + assert.Equal(t, plan.Delete.Query, + `DELETE FROM "log_entries" WHERE "id" = $1`, + ) + assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{2}}) + type record struct { Log - Keks bool `db:"keks"` + Foo bool `db:"foo"` private2 bool `db:"private2"` //nolint:unused } @@ -61,12 +78,17 @@ func TestPlanFieldTraversal(t *testing.T) { t.Error(err) } assert.Equal(t, plan.TableName, "") - assert.DeepEqual(t, plan.PrimaryKeyColumns, nil) - assert.DeepEqual(t, plan.AutoColumns, []string{"id"}) // this is okay, it does not bear significance in practice since no queries are generated + assert.DeepEqual(t, plan.AllColumnNames, []string{"id", "created_at", "Message", "foo"}) + assert.DeepEqual(t, plan.PrimaryKeyColumnNames, nil) + assert.DeepEqual(t, plan.AutoColumnNames, []string{"id"}) // this is okay, it does not bear significance in practice since no queries are generated assert.DeepEqual(t, plan.IndexByColumnName, map[string][]int{ "id": {0, 2}, "created_at": {0, 3}, - "message": {0, 4}, - "keks": {1}, + "Message": {0, 4}, + "foo": {1}, }) + + assert.Equal(t, plan.Insert.Query, "") + assert.Equal(t, plan.Update.Query, "") + assert.Equal(t, plan.Delete.Query, "") } |
