aboutsummaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/plan.go66
-rw-r--r--internal/plan_test.go46
2 files changed, 88 insertions, 24 deletions
diff --git a/internal/plan.go b/internal/plan.go
index 7dc3361..f619a5f 100644
--- a/internal/plan.go
+++ b/internal/plan.go
@@ -22,8 +22,12 @@ type Plan struct {
// Argument for reflect.Value.FieldByIndex() for each column name.
IndexByColumnName map[string][]int
+ // In dialects with UsesLastInsertID() == true, whether the ID column must be written with reflect.Value.SetInt() or reflect.Value.SetUint().
+ FillIDWithSetUint bool
+ FillIDWithSetInt bool
+
// Planned queries.
- Select PlannedQuery // only `SELECT ... FROM ...` without WHERE or any of the other clauses
+ Select PlannedQuery // only `SELECT ... FROM ... WHERE `; user supplies the rest during Select{,One}Where()
Insert PlannedQuery
Update PlannedQuery
Delete PlannedQuery
@@ -31,12 +35,12 @@ type Plan struct {
// PlannedQuery appears in type Plan.
type PlannedQuery struct {
- // Empty if the respective query type is not supported by this Plan
- // for lack of the required marker types.
+ // Empty if the respective query type is not supported by this Plan for lack of the required marker types.
Query string
- // Arguments for reflect.Value.FieldByIndex() in the correct order
- // for the query arguments of the above query.
+ // Arguments for reflect.Value.FieldByIndex() in the correct order for the query arguments of the above query.
ArgumentIndexes [][]int
+ // Arguments for reflect.Value.FieldByIndex() in the correct order for the Scan() arguments of the above query.
+ ScanIndexes [][]int
}
// PlanOpts holds additional arguments to BuildPlan().
@@ -118,12 +122,31 @@ func buildPlan(t reflect.Type, dialect Dialect, opts PlanOpts) (Plan, error) {
}
}
- // validation: LastInsertID() only works if at most one column is auto-filled
- if dialect.UsesLastInsertID() && len(p.AutoColumnNames) > 1 {
- return Plan{}, fmt.Errorf(
- "multiple columns are marked as auto-filled (%s), but this SQL dialect only supports at most one per table",
- strings.Join(p.AutoColumnNames, ", "),
- )
+ // validation: LastInsertID() only works if at most one column is auto-filled, and if that column holds an integer type
+ if dialect.UsesLastInsertID() {
+ switch len(p.AutoColumnNames) {
+ case 0:
+ // nothing to check
+ case 1:
+ columnName := p.AutoColumnNames[0]
+ field := t.FieldByIndex(p.IndexByColumnName[columnName])
+ switch field.Type.Kind() { //nolint:exhaustive // false positive
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ p.FillIDWithSetInt = true
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ p.FillIDWithSetUint = true
+ default:
+ return Plan{}, fmt.Errorf(
+ "column is marked as auto-filled (%s), but this SQL dialect only supports auto-filling struct fields with integer types",
+ strings.Join(p.AutoColumnNames, ", "),
+ )
+ }
+ default:
+ return Plan{}, fmt.Errorf(
+ "multiple columns are marked as auto-filled (%s), but this SQL dialect only supports at most one per table",
+ strings.Join(p.AutoColumnNames, ", "),
+ )
+ }
}
// prepare query strings
@@ -161,11 +184,11 @@ func (p Plan) buildSelectQueryIfPossible(dialect Dialect) PlannedQuery {
}
var (
- argumentIndexes = make([][]int, len(p.AllColumnNames))
+ scanIndexes = make([][]int, len(p.AllColumnNames))
quotedColumnNames = make([]string, len(p.AllColumnNames))
)
for idx, columnName := range p.AllColumnNames {
- argumentIndexes[idx] = p.IndexByColumnName[columnName]
+ scanIndexes[idx] = p.IndexByColumnName[columnName]
quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName)
}
@@ -174,7 +197,7 @@ func (p Plan) buildSelectQueryIfPossible(dialect Dialect) PlannedQuery {
strings.Join(quotedColumnNames, ", "),
dialect.QuoteIdentifier(p.TableName),
)
- return PlannedQuery{query, argumentIndexes}
+ return PlannedQuery{query, nil, scanIndexes}
}
func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery {
@@ -188,6 +211,7 @@ func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery {
var (
argumentIndexes = make([][]int, len(nonAutoColumnNames))
+ scanIndexes [][]int
quotedColumnNames = make([]string, len(nonAutoColumnNames))
quotedPlaceholders = make([]string, len(nonAutoColumnNames))
)
@@ -196,6 +220,14 @@ func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery {
quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName)
quotedPlaceholders[idx] = dialect.Placeholder(idx)
}
+ if len(p.AutoColumnNames) > 0 {
+ // NOTE: This is filled even if dialect.UsesLastInsertID() is false.
+ // We need this index to find the right value on which to run SetInt() or SetUint().
+ scanIndexes = make([][]int, len(p.AutoColumnNames))
+ for idx, columnName := range p.AutoColumnNames {
+ scanIndexes[idx] = p.IndexByColumnName[columnName]
+ }
+ }
query := fmt.Sprintf(
`INSERT INTO %s (%s) VALUES (%s)`,
@@ -206,7 +238,7 @@ func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery {
if len(p.AutoColumnNames) > 0 {
query += dialect.InsertSuffixForAutoColumns(p.AutoColumnNames)
}
- return PlannedQuery{query, argumentIndexes}
+ return PlannedQuery{query, argumentIndexes, scanIndexes}
}
func (p Plan) buildUpdateQueryIfPossible(dialect Dialect) PlannedQuery {
@@ -242,7 +274,7 @@ func (p Plan) buildUpdateQueryIfPossible(dialect Dialect) PlannedQuery {
strings.Join(setClauses, ", "),
strings.Join(whereClauses, " AND "),
)
- return PlannedQuery{query, slices.Concat(setArgumentIndexes, whereArgumentIndexes)}
+ return PlannedQuery{query, slices.Concat(setArgumentIndexes, whereArgumentIndexes), nil}
}
func (p Plan) buildDeleteQueryIfPossible(dialect Dialect) PlannedQuery {
@@ -264,5 +296,5 @@ func (p Plan) buildDeleteQueryIfPossible(dialect Dialect) PlannedQuery {
dialect.QuoteIdentifier(p.TableName),
strings.Join(clauses, " AND "),
)
- return PlannedQuery{query, argumentIndexes}
+ return PlannedQuery{query, argumentIndexes, nil}
}
diff --git a/internal/plan_test.go b/internal/plan_test.go
index db12943..e692556 100644
--- a/internal/plan_test.go
+++ b/internal/plan_test.go
@@ -76,13 +76,17 @@ func TestQueryConstructionBasic(t *testing.T) {
t.Error(err)
}
assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `)
- assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}})
+ assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES ($1, $2) RETURNING "ID"`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}})
+ assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}})
assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = $1, "CreatedAt" = $2 WHERE "ID" = $3`)
assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}})
+ assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
assert.Equal(t, plan.Delete.Query, `DELETE FROM "basic_records" WHERE "ID" = $1`)
assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}})
+ assert.DeepEqual(t, plan.Delete.ScanIndexes, nil)
})
t.Run("SqliteDialect", func(t *testing.T) {
@@ -91,13 +95,17 @@ func TestQueryConstructionBasic(t *testing.T) {
t.Error(err)
}
assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `)
- assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}})
+ assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}})
+ assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}})
assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = ?, "CreatedAt" = ? WHERE "ID" = ?`)
assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}})
+ assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
assert.Equal(t, plan.Delete.Query, `DELETE FROM "basic_records" WHERE "ID" = ?`)
assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}})
+ assert.DeepEqual(t, plan.Delete.ScanIndexes, nil)
})
}
@@ -116,13 +124,17 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) {
t.Error(err)
}
assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `)
- assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}})
+ assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}})
assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}})
+ assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, "")
assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
assert.Equal(t, plan.Delete.Query, "")
assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Delete.ScanIndexes, nil)
})
t.Run("SqliteDialect", func(t *testing.T) {
@@ -131,13 +143,17 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) {
t.Error(err)
}
assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `)
- assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}})
+ assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}})
assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}})
+ assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, "")
assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
assert.Equal(t, plan.Delete.Query, "")
assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Delete.ScanIndexes, nil)
})
}
@@ -157,12 +173,16 @@ func TestQueryConstructionImpossble(t *testing.T) {
assert.Equal(t, plan.Select.Query, "")
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Select.ScanIndexes, nil)
assert.Equal(t, plan.Insert.Query, "")
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, "")
assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
assert.Equal(t, plan.Delete.Query, "")
assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Delete.ScanIndexes, nil)
}
}
@@ -187,13 +207,17 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) {
t.Error(err)
}
assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `)
- assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}})
+ assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES ($1, $2, $3)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}})
+ assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = $1 WHERE "group_id" = $2 AND "name" = $3`)
assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}})
+ assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
assert.Equal(t, plan.Delete.Query, `DELETE FROM "complex_records" WHERE "group_id" = $1 AND "name" = $2`)
assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}, {1}})
+ assert.DeepEqual(t, plan.Delete.ScanIndexes, nil)
})
t.Run("SqliteDialect", func(t *testing.T) {
@@ -202,13 +226,17 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) {
t.Error(err)
}
assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `)
- assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}})
+ assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES (?, ?, ?)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}})
+ assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = ? WHERE "group_id" = ? AND "name" = ?`)
assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}})
+ assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
assert.Equal(t, plan.Delete.Query, `DELETE FROM "complex_records" WHERE "group_id" = ? AND "name" = ?`)
assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}, {1}})
+ assert.DeepEqual(t, plan.Delete.ScanIndexes, nil)
})
}
@@ -229,13 +257,17 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) {
t.Error(err)
}
assert.Equal(t, plan.Select.Query, `SELECT "id", "name", "created_at" FROM "autogenerated_records" WHERE `)
- assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}})
+ assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
assert.Equal(t, plan.Insert.Query, `INSERT INTO "autogenerated_records" ("name") VALUES ($1) RETURNING "id", "created_at"`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}})
+ assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}, {2}})
assert.Equal(t, plan.Update.Query, `UPDATE "autogenerated_records" SET "name" = $1, "created_at" = $2 WHERE "id" = $3`)
assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}})
+ assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
assert.Equal(t, plan.Delete.Query, `DELETE FROM "autogenerated_records" WHERE "id" = $1`)
assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}})
+ assert.DeepEqual(t, plan.Delete.ScanIndexes, nil)
})
t.Run("SqliteDialect", func(t *testing.T) {