diff options
Diffstat (limited to 'internal/plan.go')
| -rw-r--r-- | internal/plan.go | 185 |
1 files changed, 152 insertions, 33 deletions
diff --git a/internal/plan.go b/internal/plan.go index 0defd15..8fc24d8 100644 --- a/internal/plan.go +++ b/internal/plan.go @@ -15,25 +15,28 @@ import ( // Plan holds all information that we can derive from reflecting on a given type. // The queries held within are only valid within the context of a given SQL dialect. type Plan struct { - // Information extracted from applicable marker types (if any). - TableName string - PrimaryKeyColumns []string + TableName string // from info.TableNameIs marker (if any) + AllColumnNames []string // in order of struct fields + PrimaryKeyColumnNames []string // from info.PrimaryKeyIs marker (if any) + AutoColumnNames []string // subset of AllColumnNames where field has `,auto` marker // Argument for reflect.Value.FieldByIndex() for each column name. IndexByColumnName map[string][]int - // Which columns will be filled automatically by the DB during insert. - // This corresponds to having a tag like `db:"foo,auto"`. - // In DB dialects that use LastInsertID(), this list may have at most one element. - AutoColumns []string - - // Prepared queries (or empty strings if the respective query types are not - // supported for lack of the respective markers). - InsertQuery string - UpdateQuery string - DeleteQuery string - - // Arguments for reflect.Value.FieldByIndex() in the required order for p.InsertQuery. - InsertFieldOrder [][]int + + // Planned queries. + Insert PlannedQuery + Update PlannedQuery + Delete PlannedQuery +} + +// PlannedQuery appears in type Plan. +type PlannedQuery struct { + // Empty if the respective query type is not supported by this Plan + // for lack of the required marker types. + Query string + // Arguments for reflect.Value.FieldByIndex() in the correct order + // for the query arguments of the above query. + ArgumentIndexes [][]int } var ( @@ -41,6 +44,7 @@ var ( primaryKeyMarkerType = reflect.TypeFor[info.PrimaryKeyIs]() ) +// BuildPlan creates a new plan for the given struct type. func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) { if t.Kind() != reflect.Struct { return Plan{}, fmt.Errorf("expected record type to be a struct, but got kind %s (full type: %s.%s)", @@ -55,28 +59,33 @@ func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) { // collect information from markers and tags for _, index := range getAllAddressableFieldIndexes(t) { field := t.FieldByIndex(index) - fullTag := strings.TrimSpace(field.Tag.Get("db")) - if fullTag == "" || fullTag == "-" { - continue - } - tags := strings.Split(fullTag, ",") + tags := strings.Split(strings.TrimSpace(field.Tag.Get("db")), ",") - switch field.Type { - case tableNameMarkerType: + switch { + case field.Type == tableNameMarkerType: // only consider this marker when directly on `t` itself, not within embedded fields if len(index) == 1 { if len(tags) > 1 { - return Plan{}, fmt.Errorf("invalid table name %q (may not contain commas)", fullTag) + return Plan{}, fmt.Errorf("invalid table name %q (may not contain commas)", field.Tag.Get("db")) } p.TableName = tags[0] } - case primaryKeyMarkerType: + case field.Type == primaryKeyMarkerType: // only consider this marker when directly on `t` itself, not within embedded fields if len(index) == 1 { - p.PrimaryKeyColumns = tags + p.PrimaryKeyColumnNames = tags } + case field.Anonymous && field.Type.Kind() == reflect.Struct: + // for embedded struct fields, only consider their members, not the type itself, as a potential column + continue default: columnName, extraTags := tags[0], tags[1:] + if columnName == "-" { + continue + } + if columnName == "" { + columnName = field.Name + } if otherIndex := p.IndexByColumnName[columnName]; otherIndex != nil { return Plan{}, fmt.Errorf( "duplicate tag `db:%q` on field index %v, but also on field index %v", @@ -84,11 +93,12 @@ func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) { ) } p.IndexByColumnName[columnName] = index + p.AllColumnNames = append(p.AllColumnNames, columnName) for _, tag := range extraTags { switch tag { case "auto": - p.AutoColumns = append(p.AutoColumns, columnName) + p.AutoColumnNames = append(p.AutoColumnNames, columnName) default: return Plan{}, fmt.Errorf("unknown tag `db:%q` on field index %v", ","+tag, index) } @@ -97,7 +107,7 @@ func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) { } // validation: oblast.PrimaryKeyInfo must refer to columns that exist - for _, columnName := range p.PrimaryKeyColumns { + for _, columnName := range p.PrimaryKeyColumnNames { _, ok := p.IndexByColumnName[columnName] if !ok { return Plan{}, fmt.Errorf("PrimaryKeyInfo refers to column %[1]q, but no field has tag `db:%[1]q`", columnName) @@ -105,16 +115,17 @@ func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) { } // validation: LastInsertID() only works if at most one column is auto-filled - if dialect.UsesLastInsertID() && len(p.AutoColumns) > 1 { + if dialect.UsesLastInsertID() && len(p.AutoColumnNames) > 1 { return Plan{}, fmt.Errorf( "multiple columns are marked as auto-filled (%s), but this SQL dialect only supports at most one per table", - strings.Join(p.AutoColumns, ", "), + strings.Join(p.AutoColumnNames, ", "), ) } - // TODO: build INSERT query if possible - // TODO: build UPDATE query if possible - // TODO: build DELETE query if possible + // prepare query strings + p.Insert = p.buildInsertQueryIfPossible(dialect) + p.Update = p.buildUpdateQueryIfPossible(dialect) + p.Delete = p.buildDeleteQueryIfPossible(dialect) return p, nil } @@ -136,3 +147,111 @@ func getAllAddressableFieldIndexes(t reflect.Type) (result [][]int) { } return result } + +func (p Plan) getNonAutoColumnNames() []string { + result := make([]string, 0, len(p.AllColumnNames)-len(p.AutoColumnNames)) + for _, columnName := range p.AllColumnNames { + if !slices.Contains(p.AutoColumnNames, columnName) { + result = append(result, columnName) + } + } + return result +} + +func (p Plan) getNonPrimaryKeyColumnNames() []string { + result := make([]string, 0, len(p.AllColumnNames)-len(p.PrimaryKeyColumnNames)) + for _, columnName := range p.AllColumnNames { + if !slices.Contains(p.PrimaryKeyColumnNames, columnName) { + result = append(result, columnName) + } + } + return result +} + +func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery { + if p.TableName == "" || len(p.AllColumnNames) == 0 { + return PlannedQuery{Query: ""} + } + nonAutoColumnNames := p.getNonAutoColumnNames() + if len(nonAutoColumnNames) == 0 { + return PlannedQuery{Query: ""} + } + + var ( + argumentIndexes = make([][]int, len(nonAutoColumnNames)) + quotedColumnNames = make([]string, len(nonAutoColumnNames)) + quotedPlaceholders = make([]string, len(nonAutoColumnNames)) + ) + for idx, columnName := range nonAutoColumnNames { + argumentIndexes[idx] = p.IndexByColumnName[columnName] + quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName) + quotedPlaceholders[idx] = dialect.Placeholder(idx) + } + + query := fmt.Sprintf( + `INSERT INTO %s (%s) VALUES (%s)%s`, + dialect.QuoteIdentifier(p.TableName), + strings.Join(quotedColumnNames, ", "), + strings.Join(quotedPlaceholders, ", "), + dialect.InsertSuffixForAutoColumns(p.AutoColumnNames), + ) + return PlannedQuery{query, argumentIndexes} +} + +func (p Plan) buildUpdateQueryIfPossible(dialect Dialect) PlannedQuery { + if p.TableName == "" || len(p.PrimaryKeyColumnNames) == 0 { + return PlannedQuery{Query: ""} + } + nonPrimaryKeyColumnNames := p.getNonPrimaryKeyColumnNames() + if len(nonPrimaryKeyColumnNames) == 0 { + return PlannedQuery{Query: ""} + } + + var ( + setArgumentIndexes = make([][]int, len(nonPrimaryKeyColumnNames)) + setClauses = make([]string, len(nonPrimaryKeyColumnNames)) + ) + for idx, columnName := range nonPrimaryKeyColumnNames { + setArgumentIndexes[idx] = p.IndexByColumnName[columnName] + setClauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx)) + } + + var ( + whereArgumentIndexes = make([][]int, len(p.PrimaryKeyColumnNames)) + whereClauses = make([]string, len(p.PrimaryKeyColumnNames)) + ) + for idx, columnName := range p.PrimaryKeyColumnNames { + whereArgumentIndexes[idx] = p.IndexByColumnName[columnName] + whereClauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx+len(setClauses))) + } + + query := fmt.Sprintf( + `UPDATE %s SET %s WHERE %s`, + dialect.QuoteIdentifier(p.TableName), + strings.Join(setClauses, ", "), + strings.Join(whereClauses, " AND "), + ) + return PlannedQuery{query, slices.Concat(setArgumentIndexes, whereArgumentIndexes)} +} + +func (p Plan) buildDeleteQueryIfPossible(dialect Dialect) PlannedQuery { + if p.TableName == "" || len(p.PrimaryKeyColumnNames) == 0 { + return PlannedQuery{Query: ""} + } + + var ( + argumentIndexes = make([][]int, len(p.PrimaryKeyColumnNames)) + clauses = make([]string, len(p.PrimaryKeyColumnNames)) + ) + for idx, columnName := range p.PrimaryKeyColumnNames { + argumentIndexes[idx] = p.IndexByColumnName[columnName] + clauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx)) + } + + query := fmt.Sprintf( + `DELETE FROM %s WHERE %s`, + dialect.QuoteIdentifier(p.TableName), + strings.Join(clauses, " AND "), + ) + return PlannedQuery{query, argumentIndexes} +} |
