diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/plan.go | 66 | ||||
| -rw-r--r-- | internal/plan_test.go | 46 |
2 files changed, 88 insertions, 24 deletions
diff --git a/internal/plan.go b/internal/plan.go index 7dc3361..f619a5f 100644 --- a/internal/plan.go +++ b/internal/plan.go @@ -22,8 +22,12 @@ type Plan struct { // Argument for reflect.Value.FieldByIndex() for each column name. IndexByColumnName map[string][]int + // In dialects with UsesLastInsertID() == true, whether the ID column must be written with reflect.Value.SetInt() or reflect.Value.SetUint(). + FillIDWithSetUint bool + FillIDWithSetInt bool + // Planned queries. - Select PlannedQuery // only `SELECT ... FROM ...` without WHERE or any of the other clauses + Select PlannedQuery // only `SELECT ... FROM ... WHERE `; user supplies the rest during Select{,One}Where() Insert PlannedQuery Update PlannedQuery Delete PlannedQuery @@ -31,12 +35,12 @@ type Plan struct { // PlannedQuery appears in type Plan. type PlannedQuery struct { - // Empty if the respective query type is not supported by this Plan - // for lack of the required marker types. + // Empty if the respective query type is not supported by this Plan for lack of the required marker types. Query string - // Arguments for reflect.Value.FieldByIndex() in the correct order - // for the query arguments of the above query. + // Arguments for reflect.Value.FieldByIndex() in the correct order for the query arguments of the above query. ArgumentIndexes [][]int + // Arguments for reflect.Value.FieldByIndex() in the correct order for the Scan() arguments of the above query. + ScanIndexes [][]int } // PlanOpts holds additional arguments to BuildPlan(). @@ -118,12 +122,31 @@ func buildPlan(t reflect.Type, dialect Dialect, opts PlanOpts) (Plan, error) { } } - // validation: LastInsertID() only works if at most one column is auto-filled - if dialect.UsesLastInsertID() && len(p.AutoColumnNames) > 1 { - return Plan{}, fmt.Errorf( - "multiple columns are marked as auto-filled (%s), but this SQL dialect only supports at most one per table", - strings.Join(p.AutoColumnNames, ", "), - ) + // validation: LastInsertID() only works if at most one column is auto-filled, and if that column holds an integer type + if dialect.UsesLastInsertID() { + switch len(p.AutoColumnNames) { + case 0: + // nothing to check + case 1: + columnName := p.AutoColumnNames[0] + field := t.FieldByIndex(p.IndexByColumnName[columnName]) + switch field.Type.Kind() { //nolint:exhaustive // false positive + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + p.FillIDWithSetInt = true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + p.FillIDWithSetUint = true + default: + return Plan{}, fmt.Errorf( + "column is marked as auto-filled (%s), but this SQL dialect only supports auto-filling struct fields with integer types", + strings.Join(p.AutoColumnNames, ", "), + ) + } + default: + return Plan{}, fmt.Errorf( + "multiple columns are marked as auto-filled (%s), but this SQL dialect only supports at most one per table", + strings.Join(p.AutoColumnNames, ", "), + ) + } } // prepare query strings @@ -161,11 +184,11 @@ func (p Plan) buildSelectQueryIfPossible(dialect Dialect) PlannedQuery { } var ( - argumentIndexes = make([][]int, len(p.AllColumnNames)) + scanIndexes = make([][]int, len(p.AllColumnNames)) quotedColumnNames = make([]string, len(p.AllColumnNames)) ) for idx, columnName := range p.AllColumnNames { - argumentIndexes[idx] = p.IndexByColumnName[columnName] + scanIndexes[idx] = p.IndexByColumnName[columnName] quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName) } @@ -174,7 +197,7 @@ func (p Plan) buildSelectQueryIfPossible(dialect Dialect) PlannedQuery { strings.Join(quotedColumnNames, ", "), dialect.QuoteIdentifier(p.TableName), ) - return PlannedQuery{query, argumentIndexes} + return PlannedQuery{query, nil, scanIndexes} } func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery { @@ -188,6 +211,7 @@ func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery { var ( argumentIndexes = make([][]int, len(nonAutoColumnNames)) + scanIndexes [][]int quotedColumnNames = make([]string, len(nonAutoColumnNames)) quotedPlaceholders = make([]string, len(nonAutoColumnNames)) ) @@ -196,6 +220,14 @@ func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery { quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName) quotedPlaceholders[idx] = dialect.Placeholder(idx) } + if len(p.AutoColumnNames) > 0 { + // NOTE: This is filled even if dialect.UsesLastInsertID() is false. + // We need this index to find the right value on which to run SetInt() or SetUint(). + scanIndexes = make([][]int, len(p.AutoColumnNames)) + for idx, columnName := range p.AutoColumnNames { + scanIndexes[idx] = p.IndexByColumnName[columnName] + } + } query := fmt.Sprintf( `INSERT INTO %s (%s) VALUES (%s)`, @@ -206,7 +238,7 @@ func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery { if len(p.AutoColumnNames) > 0 { query += dialect.InsertSuffixForAutoColumns(p.AutoColumnNames) } - return PlannedQuery{query, argumentIndexes} + return PlannedQuery{query, argumentIndexes, scanIndexes} } func (p Plan) buildUpdateQueryIfPossible(dialect Dialect) PlannedQuery { @@ -242,7 +274,7 @@ func (p Plan) buildUpdateQueryIfPossible(dialect Dialect) PlannedQuery { strings.Join(setClauses, ", "), strings.Join(whereClauses, " AND "), ) - return PlannedQuery{query, slices.Concat(setArgumentIndexes, whereArgumentIndexes)} + return PlannedQuery{query, slices.Concat(setArgumentIndexes, whereArgumentIndexes), nil} } func (p Plan) buildDeleteQueryIfPossible(dialect Dialect) PlannedQuery { @@ -264,5 +296,5 @@ func (p Plan) buildDeleteQueryIfPossible(dialect Dialect) PlannedQuery { dialect.QuoteIdentifier(p.TableName), strings.Join(clauses, " AND "), ) - return PlannedQuery{query, argumentIndexes} + return PlannedQuery{query, argumentIndexes, nil} } diff --git a/internal/plan_test.go b/internal/plan_test.go index db12943..e692556 100644 --- a/internal/plan_test.go +++ b/internal/plan_test.go @@ -76,13 +76,17 @@ func TestQueryConstructionBasic(t *testing.T) { t.Error(err) } assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}}) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES ($1, $2) RETURNING "ID"`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = $1, "CreatedAt" = $2 WHERE "ID" = $3`) assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) assert.Equal(t, plan.Delete.Query, `DELETE FROM "basic_records" WHERE "ID" = $1`) assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) }) t.Run("SqliteDialect", func(t *testing.T) { @@ -91,13 +95,17 @@ func TestQueryConstructionBasic(t *testing.T) { t.Error(err) } assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}}) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = ?, "CreatedAt" = ? WHERE "ID" = ?`) assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) assert.Equal(t, plan.Delete.Query, `DELETE FROM "basic_records" WHERE "ID" = ?`) assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) }) } @@ -116,13 +124,17 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { t.Error(err) } assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, "") assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) assert.Equal(t, plan.Delete.Query, "") assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) }) t.Run("SqliteDialect", func(t *testing.T) { @@ -131,13 +143,17 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { t.Error(err) } assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, "") assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) assert.Equal(t, plan.Delete.Query, "") assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) }) } @@ -157,12 +173,16 @@ func TestQueryConstructionImpossble(t *testing.T) { assert.Equal(t, plan.Select.Query, "") assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, nil) assert.Equal(t, plan.Insert.Query, "") assert.DeepEqual(t, plan.Insert.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, "") assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) assert.Equal(t, plan.Delete.Query, "") assert.DeepEqual(t, plan.Delete.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) } } @@ -187,13 +207,17 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { t.Error(err) } assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}}) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES ($1, $2, $3)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = $1 WHERE "group_id" = $2 AND "name" = $3`) assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}}) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) assert.Equal(t, plan.Delete.Query, `DELETE FROM "complex_records" WHERE "group_id" = $1 AND "name" = $2`) assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) }) t.Run("SqliteDialect", func(t *testing.T) { @@ -202,13 +226,17 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { t.Error(err) } assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}}) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES (?, ?, ?)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = ? WHERE "group_id" = ? AND "name" = ?`) assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}}) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) assert.Equal(t, plan.Delete.Query, `DELETE FROM "complex_records" WHERE "group_id" = ? AND "name" = ?`) assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) }) } @@ -229,13 +257,17 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) { t.Error(err) } assert.Equal(t, plan.Select.Query, `SELECT "id", "name", "created_at" FROM "autogenerated_records" WHERE `) - assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}}) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "autogenerated_records" ("name") VALUES ($1) RETURNING "id", "created_at"`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}, {2}}) assert.Equal(t, plan.Update.Query, `UPDATE "autogenerated_records" SET "name" = $1, "created_at" = $2 WHERE "id" = $3`) assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) assert.Equal(t, plan.Delete.Query, `DELETE FROM "autogenerated_records" WHERE "id" = $1`) assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) }) t.Run("SqliteDialect", func(t *testing.T) { |
