aboutsummaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/dialect.go2
-rw-r--r--internal/plan.go185
-rw-r--r--internal/plan_test.go46
3 files changed, 187 insertions, 46 deletions
diff --git a/internal/dialect.go b/internal/dialect.go
index 0cf90a2..e6db5b8 100644
--- a/internal/dialect.go
+++ b/internal/dialect.go
@@ -20,7 +20,7 @@ type Dialect interface {
// PostgresDialect is the dialect of PostgreSQL databases.
type PostgresDialect struct{}
-func (PostgresDialect) Placeholder(i int) string { return "$" + strconv.Itoa(i) }
+func (PostgresDialect) Placeholder(i int) string { return "$" + strconv.Itoa(i+1) }
func (PostgresDialect) QuoteIdentifier(name string) string { return `"` + name + `"` }
func (PostgresDialect) UsesLastInsertID() bool { return false }
diff --git a/internal/plan.go b/internal/plan.go
index 0defd15..8fc24d8 100644
--- a/internal/plan.go
+++ b/internal/plan.go
@@ -15,25 +15,28 @@ import (
// Plan holds all information that we can derive from reflecting on a given type.
// The queries held within are only valid within the context of a given SQL dialect.
type Plan struct {
- // Information extracted from applicable marker types (if any).
- TableName string
- PrimaryKeyColumns []string
+ TableName string // from info.TableNameIs marker (if any)
+ AllColumnNames []string // in order of struct fields
+ PrimaryKeyColumnNames []string // from info.PrimaryKeyIs marker (if any)
+ AutoColumnNames []string // subset of AllColumnNames where field has `,auto` marker
// Argument for reflect.Value.FieldByIndex() for each column name.
IndexByColumnName map[string][]int
- // Which columns will be filled automatically by the DB during insert.
- // This corresponds to having a tag like `db:"foo,auto"`.
- // In DB dialects that use LastInsertID(), this list may have at most one element.
- AutoColumns []string
-
- // Prepared queries (or empty strings if the respective query types are not
- // supported for lack of the respective markers).
- InsertQuery string
- UpdateQuery string
- DeleteQuery string
-
- // Arguments for reflect.Value.FieldByIndex() in the required order for p.InsertQuery.
- InsertFieldOrder [][]int
+
+ // Planned queries.
+ Insert PlannedQuery
+ Update PlannedQuery
+ Delete PlannedQuery
+}
+
+// PlannedQuery appears in type Plan.
+type PlannedQuery struct {
+ // Empty if the respective query type is not supported by this Plan
+ // for lack of the required marker types.
+ Query string
+ // Arguments for reflect.Value.FieldByIndex() in the correct order
+ // for the query arguments of the above query.
+ ArgumentIndexes [][]int
}
var (
@@ -41,6 +44,7 @@ var (
primaryKeyMarkerType = reflect.TypeFor[info.PrimaryKeyIs]()
)
+// BuildPlan creates a new plan for the given struct type.
func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) {
if t.Kind() != reflect.Struct {
return Plan{}, fmt.Errorf("expected record type to be a struct, but got kind %s (full type: %s.%s)",
@@ -55,28 +59,33 @@ func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) {
// collect information from markers and tags
for _, index := range getAllAddressableFieldIndexes(t) {
field := t.FieldByIndex(index)
- fullTag := strings.TrimSpace(field.Tag.Get("db"))
- if fullTag == "" || fullTag == "-" {
- continue
- }
- tags := strings.Split(fullTag, ",")
+ tags := strings.Split(strings.TrimSpace(field.Tag.Get("db")), ",")
- switch field.Type {
- case tableNameMarkerType:
+ switch {
+ case field.Type == tableNameMarkerType:
// only consider this marker when directly on `t` itself, not within embedded fields
if len(index) == 1 {
if len(tags) > 1 {
- return Plan{}, fmt.Errorf("invalid table name %q (may not contain commas)", fullTag)
+ return Plan{}, fmt.Errorf("invalid table name %q (may not contain commas)", field.Tag.Get("db"))
}
p.TableName = tags[0]
}
- case primaryKeyMarkerType:
+ case field.Type == primaryKeyMarkerType:
// only consider this marker when directly on `t` itself, not within embedded fields
if len(index) == 1 {
- p.PrimaryKeyColumns = tags
+ p.PrimaryKeyColumnNames = tags
}
+ case field.Anonymous && field.Type.Kind() == reflect.Struct:
+ // for embedded struct fields, only consider their members, not the type itself, as a potential column
+ continue
default:
columnName, extraTags := tags[0], tags[1:]
+ if columnName == "-" {
+ continue
+ }
+ if columnName == "" {
+ columnName = field.Name
+ }
if otherIndex := p.IndexByColumnName[columnName]; otherIndex != nil {
return Plan{}, fmt.Errorf(
"duplicate tag `db:%q` on field index %v, but also on field index %v",
@@ -84,11 +93,12 @@ func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) {
)
}
p.IndexByColumnName[columnName] = index
+ p.AllColumnNames = append(p.AllColumnNames, columnName)
for _, tag := range extraTags {
switch tag {
case "auto":
- p.AutoColumns = append(p.AutoColumns, columnName)
+ p.AutoColumnNames = append(p.AutoColumnNames, columnName)
default:
return Plan{}, fmt.Errorf("unknown tag `db:%q` on field index %v", ","+tag, index)
}
@@ -97,7 +107,7 @@ func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) {
}
// validation: oblast.PrimaryKeyInfo must refer to columns that exist
- for _, columnName := range p.PrimaryKeyColumns {
+ for _, columnName := range p.PrimaryKeyColumnNames {
_, ok := p.IndexByColumnName[columnName]
if !ok {
return Plan{}, fmt.Errorf("PrimaryKeyInfo refers to column %[1]q, but no field has tag `db:%[1]q`", columnName)
@@ -105,16 +115,17 @@ func BuildPlan(t reflect.Type, dialect Dialect) (Plan, error) {
}
// validation: LastInsertID() only works if at most one column is auto-filled
- if dialect.UsesLastInsertID() && len(p.AutoColumns) > 1 {
+ if dialect.UsesLastInsertID() && len(p.AutoColumnNames) > 1 {
return Plan{}, fmt.Errorf(
"multiple columns are marked as auto-filled (%s), but this SQL dialect only supports at most one per table",
- strings.Join(p.AutoColumns, ", "),
+ strings.Join(p.AutoColumnNames, ", "),
)
}
- // TODO: build INSERT query if possible
- // TODO: build UPDATE query if possible
- // TODO: build DELETE query if possible
+ // prepare query strings
+ p.Insert = p.buildInsertQueryIfPossible(dialect)
+ p.Update = p.buildUpdateQueryIfPossible(dialect)
+ p.Delete = p.buildDeleteQueryIfPossible(dialect)
return p, nil
}
@@ -136,3 +147,111 @@ func getAllAddressableFieldIndexes(t reflect.Type) (result [][]int) {
}
return result
}
+
+func (p Plan) getNonAutoColumnNames() []string {
+ result := make([]string, 0, len(p.AllColumnNames)-len(p.AutoColumnNames))
+ for _, columnName := range p.AllColumnNames {
+ if !slices.Contains(p.AutoColumnNames, columnName) {
+ result = append(result, columnName)
+ }
+ }
+ return result
+}
+
+func (p Plan) getNonPrimaryKeyColumnNames() []string {
+ result := make([]string, 0, len(p.AllColumnNames)-len(p.PrimaryKeyColumnNames))
+ for _, columnName := range p.AllColumnNames {
+ if !slices.Contains(p.PrimaryKeyColumnNames, columnName) {
+ result = append(result, columnName)
+ }
+ }
+ return result
+}
+
+func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery {
+ if p.TableName == "" || len(p.AllColumnNames) == 0 {
+ return PlannedQuery{Query: ""}
+ }
+ nonAutoColumnNames := p.getNonAutoColumnNames()
+ if len(nonAutoColumnNames) == 0 {
+ return PlannedQuery{Query: ""}
+ }
+
+ var (
+ argumentIndexes = make([][]int, len(nonAutoColumnNames))
+ quotedColumnNames = make([]string, len(nonAutoColumnNames))
+ quotedPlaceholders = make([]string, len(nonAutoColumnNames))
+ )
+ for idx, columnName := range nonAutoColumnNames {
+ argumentIndexes[idx] = p.IndexByColumnName[columnName]
+ quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName)
+ quotedPlaceholders[idx] = dialect.Placeholder(idx)
+ }
+
+ query := fmt.Sprintf(
+ `INSERT INTO %s (%s) VALUES (%s)%s`,
+ dialect.QuoteIdentifier(p.TableName),
+ strings.Join(quotedColumnNames, ", "),
+ strings.Join(quotedPlaceholders, ", "),
+ dialect.InsertSuffixForAutoColumns(p.AutoColumnNames),
+ )
+ return PlannedQuery{query, argumentIndexes}
+}
+
+func (p Plan) buildUpdateQueryIfPossible(dialect Dialect) PlannedQuery {
+ if p.TableName == "" || len(p.PrimaryKeyColumnNames) == 0 {
+ return PlannedQuery{Query: ""}
+ }
+ nonPrimaryKeyColumnNames := p.getNonPrimaryKeyColumnNames()
+ if len(nonPrimaryKeyColumnNames) == 0 {
+ return PlannedQuery{Query: ""}
+ }
+
+ var (
+ setArgumentIndexes = make([][]int, len(nonPrimaryKeyColumnNames))
+ setClauses = make([]string, len(nonPrimaryKeyColumnNames))
+ )
+ for idx, columnName := range nonPrimaryKeyColumnNames {
+ setArgumentIndexes[idx] = p.IndexByColumnName[columnName]
+ setClauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx))
+ }
+
+ var (
+ whereArgumentIndexes = make([][]int, len(p.PrimaryKeyColumnNames))
+ whereClauses = make([]string, len(p.PrimaryKeyColumnNames))
+ )
+ for idx, columnName := range p.PrimaryKeyColumnNames {
+ whereArgumentIndexes[idx] = p.IndexByColumnName[columnName]
+ whereClauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx+len(setClauses)))
+ }
+
+ query := fmt.Sprintf(
+ `UPDATE %s SET %s WHERE %s`,
+ dialect.QuoteIdentifier(p.TableName),
+ strings.Join(setClauses, ", "),
+ strings.Join(whereClauses, " AND "),
+ )
+ return PlannedQuery{query, slices.Concat(setArgumentIndexes, whereArgumentIndexes)}
+}
+
+func (p Plan) buildDeleteQueryIfPossible(dialect Dialect) PlannedQuery {
+ if p.TableName == "" || len(p.PrimaryKeyColumnNames) == 0 {
+ return PlannedQuery{Query: ""}
+ }
+
+ var (
+ argumentIndexes = make([][]int, len(p.PrimaryKeyColumnNames))
+ clauses = make([]string, len(p.PrimaryKeyColumnNames))
+ )
+ for idx, columnName := range p.PrimaryKeyColumnNames {
+ argumentIndexes[idx] = p.IndexByColumnName[columnName]
+ clauses[idx] = fmt.Sprintf("%s = %s", dialect.QuoteIdentifier(columnName), dialect.Placeholder(idx))
+ }
+
+ query := fmt.Sprintf(
+ `DELETE FROM %s WHERE %s`,
+ dialect.QuoteIdentifier(p.TableName),
+ strings.Join(clauses, " AND "),
+ )
+ return PlannedQuery{query, argumentIndexes}
+}
diff --git a/internal/plan_test.go b/internal/plan_test.go
index 827c6e4..570833c 100644
--- a/internal/plan_test.go
+++ b/internal/plan_test.go
@@ -19,8 +19,9 @@ func TestPlanFieldTraversal(t *testing.T) {
info.PrimaryKeyIs `db:"id"`
ID int64 `db:"id,auto"`
CreatedAt time.Time `db:"created_at"`
- Message string `db:"message"`
- private1 bool `db:"private1"` //nolint:unused
+ Message string
+ private1 bool `db:"private1"` //nolint:unused
+ Ignored any `db:"-"`
}
// assert on interface implementations
@@ -31,24 +32,40 @@ func TestPlanFieldTraversal(t *testing.T) {
// check that the plan for Log:
// 1. has no IndexByColumnName entries for marker types
- // 2. ignores "private1" because it cannot be written through reflection
- // 3. recognizes "id" as an autofilled column
+ // 2. uses the field name as a column name for "Message"
+ // 3. ignores "private1" because it cannot be written through reflection
+ // 4. ignores "Ignored" because its column name is "-"
+ // 5. recognizes "id" as an autofilled column
plan, err := internal.BuildPlan(reflect.TypeFor[Log](), internal.PostgresDialect{})
if err != nil {
t.Error(err)
}
assert.Equal(t, plan.TableName, "log_entries")
- assert.DeepEqual(t, plan.PrimaryKeyColumns, []string{"id"})
- assert.DeepEqual(t, plan.AutoColumns, []string{"id"})
+ assert.DeepEqual(t, plan.AllColumnNames, []string{"id", "created_at", "Message"})
+ assert.DeepEqual(t, plan.PrimaryKeyColumnNames, []string{"id"})
+ assert.DeepEqual(t, plan.AutoColumnNames, []string{"id"})
assert.DeepEqual(t, plan.IndexByColumnName, map[string][]int{
"id": {2},
"created_at": {3},
- "message": {4},
+ "Message": {4},
})
+ assert.Equal(t, plan.Insert.Query,
+ `INSERT INTO "log_entries" ("created_at", "Message") VALUES ($1, $2) RETURNING "id"`,
+ )
+ assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{3}, {4}})
+ assert.Equal(t, plan.Update.Query,
+ `UPDATE "log_entries" SET "created_at" = $1, "Message" = $2 WHERE "id" = $3`,
+ )
+ assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{3}, {4}, {2}})
+ assert.Equal(t, plan.Delete.Query,
+ `DELETE FROM "log_entries" WHERE "id" = $1`,
+ )
+ assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{2}})
+
type record struct {
Log
- Keks bool `db:"keks"`
+ Foo bool `db:"foo"`
private2 bool `db:"private2"` //nolint:unused
}
@@ -61,12 +78,17 @@ func TestPlanFieldTraversal(t *testing.T) {
t.Error(err)
}
assert.Equal(t, plan.TableName, "")
- assert.DeepEqual(t, plan.PrimaryKeyColumns, nil)
- assert.DeepEqual(t, plan.AutoColumns, []string{"id"}) // this is okay, it does not bear significance in practice since no queries are generated
+ assert.DeepEqual(t, plan.AllColumnNames, []string{"id", "created_at", "Message", "foo"})
+ assert.DeepEqual(t, plan.PrimaryKeyColumnNames, nil)
+ assert.DeepEqual(t, plan.AutoColumnNames, []string{"id"}) // this is okay, it does not bear significance in practice since no queries are generated
assert.DeepEqual(t, plan.IndexByColumnName, map[string][]int{
"id": {0, 2},
"created_at": {0, 3},
- "message": {0, 4},
- "keks": {1},
+ "Message": {0, 4},
+ "foo": {1},
})
+
+ assert.Equal(t, plan.Insert.Query, "")
+ assert.Equal(t, plan.Update.Query, "")
+ assert.Equal(t, plan.Delete.Query, "")
}