diff options
| -rw-r--r-- | dialect.go | 45 | ||||
| -rw-r--r-- | plan.go | 16 | ||||
| -rw-r--r-- | plan_test.go | 42 |
3 files changed, 99 insertions, 4 deletions
@@ -4,6 +4,7 @@ package oblast import ( + "fmt" "strconv" "strings" ) @@ -37,6 +38,12 @@ type Dialect interface { // If UsesLastInsertID is true, this is usually not needed and the empty // string can be returned. InsertSuffixForAutoColumns(columns []string) string + + // UpsertClause generates an "ON CONFLICT" or similar clause + // that can be appended to an INSERT query to make it fall back to + // behave like UPDATE if a record with the same primary key already exists. + // This is only used for record types that have a primary key. + UpsertClause(pkColumns, otherColumns []string) string } // MysqlDialect is the dialect of MySQL and MariaDB databases. @@ -51,6 +58,20 @@ func (mysqlDialect) QuoteIdentifier(name string) string { return func (mysqlDialect) UsesLastInsertID() bool { return true } func (mysqlDialect) InsertSuffixForAutoColumns(columns []string) string { return "" } +func (d mysqlDialect) UpsertClause(pkColumns, otherColumns []string) string { + clauses := make([]string, max(1, len(otherColumns))) + if len(otherColumns) == 0 { + // we need at least one UPDATE clause; if there are no non-PK columns, + // we can just use one of the PK columns, updating those is a safe no-op + clauses[0] = fmt.Sprintf(`%[1]s = VALUES(%[1]s)`, d.QuoteIdentifier(pkColumns[0])) + } else { + for idx, name := range otherColumns { + clauses[idx] = fmt.Sprintf(`%[1]s = VALUES(%[1]s)`, d.QuoteIdentifier(name)) + } + } + return ` ON DUPLICATE KEY UPDATE ` + strings.Join(clauses, ", ") +} + // PostgresDialect is the dialect of PostgreSQL databases. func PostgresDialect() Dialect { return postgresDialect{} @@ -62,14 +83,31 @@ func (postgresDialect) Placeholder(i int) string { return "$" + strcon func (postgresDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } func (postgresDialect) UsesLastInsertID() bool { return false } -func (p postgresDialect) InsertSuffixForAutoColumns(columns []string) string { +func (d postgresDialect) InsertSuffixForAutoColumns(columns []string) string { quotedColumns := make([]string, len(columns)) for idx, name := range columns { - quotedColumns[idx] = p.QuoteIdentifier(name) + quotedColumns[idx] = d.QuoteIdentifier(name) } return ` RETURNING ` + strings.Join(quotedColumns, ", ") } +func (d postgresDialect) UpsertClause(pkColumns, otherColumns []string) string { + quotedPkColumns := make([]string, len(pkColumns)) + for idx, name := range pkColumns { + quotedPkColumns[idx] = d.QuoteIdentifier(name) + } + clauses := make([]string, len(otherColumns)) + for idx, name := range otherColumns { + clauses[idx] = fmt.Sprintf(`%[1]s = EXCLUDED.%[1]s`, d.QuoteIdentifier(name)) + } + if len(otherColumns) == 0 { + return fmt.Sprintf(` ON CONFLICT (%s) DO NOTHING`, strings.Join(quotedPkColumns, ", ")) + } else { + return fmt.Sprintf(` ON CONFLICT (%s) DO UPDATE SET %s`, + strings.Join(quotedPkColumns, ", "), strings.Join(clauses, ", ")) + } +} + // SqliteDialect is the dialect of SQLite databases. func SqliteDialect() Dialect { return sqliteDialect{} @@ -81,3 +119,6 @@ func (sqliteDialect) Placeholder(_ int) string { retur func (sqliteDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } func (sqliteDialect) UsesLastInsertID() bool { return true } func (sqliteDialect) InsertSuffixForAutoColumns(columns []string) string { return "" } +func (sqliteDialect) UpsertClause(pkColumns, otherColumns []string) string { + return postgresDialect{}.UpsertClause(pkColumns, otherColumns) +} @@ -33,6 +33,7 @@ type plan struct { // Planned queries. Select plannedQuery // only `SELECT ... FROM ... WHERE `; user supplies the rest during Select{,One}Where() Insert plannedQuery + Upsert plannedQuery Update plannedQuery Delete plannedQuery } @@ -198,7 +199,8 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) { // prepare query strings p.Select = p.buildSelectQueryIfPossible(dialect) - p.Insert = p.buildInsertQueryIfPossible(dialect) + p.Insert = p.buildInsertQueryIfPossible(dialect, false) + p.Upsert = p.buildInsertQueryIfPossible(dialect, true) p.Update = p.buildUpdateQueryIfPossible(dialect) p.Delete = p.buildDeleteQueryIfPossible(dialect) @@ -247,7 +249,7 @@ func (p plan) buildSelectQueryIfPossible(dialect Dialect) plannedQuery { return plannedQuery{query, nil, scanIndexes} } -func (p plan) buildInsertQueryIfPossible(dialect Dialect) plannedQuery { +func (p plan) buildInsertQueryIfPossible(dialect Dialect, isUpsert bool) plannedQuery { if p.TableName == "" || len(p.AllColumnNames) == 0 { return plannedQuery{Query: ""} } @@ -256,6 +258,13 @@ func (p plan) buildInsertQueryIfPossible(dialect Dialect) plannedQuery { return plannedQuery{Query: ""} } + // UPSERT queries specifically are only generated if we have non-auto primary keys: + // - cannot hit a key conflict if there are no keys + // - cannot hit a key conflict on insert if all keys are autogenerated (and thus we never supply them during INSERT) + if isUpsert && !slices.ContainsFunc(p.PrimaryKeyColumnNames, func(n string) bool { return !slices.Contains(p.AutoColumnNames, n) }) { + return plannedQuery{Query: ""} + } + var ( argumentIndexes = make([][]int, len(nonAutoColumnNames)) scanIndexes [][]int @@ -282,6 +291,9 @@ func (p plan) buildInsertQueryIfPossible(dialect Dialect) plannedQuery { strings.Join(quotedColumnNames, ", "), strings.Join(quotedPlaceholders, ", "), ) + if isUpsert { + query += dialect.UpsertClause(p.PrimaryKeyColumnNames, p.getNonPrimaryKeyColumnNames()) + } if len(p.AutoColumnNames) > 0 { query += dialect.InsertSuffixForAutoColumns(p.AutoColumnNames) } diff --git a/plan_test.go b/plan_test.go index 772c14a..08c7252 100644 --- a/plan_test.go +++ b/plan_test.go @@ -80,6 +80,9 @@ func TestQueryConstructionBasic(t *testing.T) { assert.Equal(t, plan.Insert.Query, "INSERT INTO `basic_records` (`Description`, `CreatedAt`) VALUES (?, ?)") assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) + assert.Equal(t, plan.Upsert.Query, "") + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, "UPDATE `basic_records` SET `Description` = ?, `CreatedAt` = ? WHERE `ID` = ?") assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -99,6 +102,9 @@ func TestQueryConstructionBasic(t *testing.T) { assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES ($1, $2) RETURNING "ID"`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) + assert.Equal(t, plan.Upsert.Query, "") + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = $1, "CreatedAt" = $2 WHERE "ID" = $3`) assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -118,6 +124,9 @@ func TestQueryConstructionBasic(t *testing.T) { assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) + assert.Equal(t, plan.Upsert.Query, "") + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = ?, "CreatedAt" = ? WHERE "ID" = ?`) assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -148,6 +157,9 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) { assert.Equal(t, plan.Insert.Query, "INSERT INTO `foo_bar_relations` (`foo_id`, `bar_id`) VALUES (?, ?)") assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Upsert.Query, "INSERT INTO `foo_bar_relations` (`foo_id`, `bar_id`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `foo_id` = VALUES(`foo_id`)") + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, "") assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -167,6 +179,9 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) { assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Upsert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2) ON CONFLICT ("foo_id", "bar_id") DO NOTHING`) + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, "") assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -186,6 +201,9 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) { assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Upsert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?) ON CONFLICT ("foo_id", "bar_id") DO NOTHING`) + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, [][]int{{0}, {1}}) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, "") assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -215,6 +233,9 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { assert.Equal(t, plan.Insert.Query, "INSERT INTO `foo_bar_relations` (`foo_id`, `bar_id`) VALUES (?, ?)") assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Upsert.Query, "") + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, "") assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -234,6 +255,9 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Upsert.Query, "") + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, "") assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -253,6 +277,9 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Upsert.Query, "") + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, "") assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -282,6 +309,9 @@ func TestQueryConstructionImpossble(t *testing.T) { assert.Equal(t, plan.Insert.Query, "") assert.DeepEqual(t, plan.Insert.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Upsert.Query, "") + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, "") assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -318,6 +348,9 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { assert.Equal(t, plan.Insert.Query, "INSERT INTO `complex_records` (`group_id`, `name`, `created_at`) VALUES (?, ?, ?)") assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Upsert.Query, "INSERT INTO `complex_records` (`group_id`, `name`, `created_at`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `created_at` = VALUES(`created_at`)") + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, "UPDATE `complex_records` SET `created_at` = ? WHERE `group_id` = ? AND `name` = ?") assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}}) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -337,6 +370,9 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES ($1, $2, $3)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Upsert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES ($1, $2, $3) ON CONFLICT ("group_id", "name") DO UPDATE SET "created_at" = EXCLUDED."created_at"`) + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = $1 WHERE "group_id" = $2 AND "name" = $3`) assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}}) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -356,6 +392,9 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES (?, ?, ?)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, nil) + assert.Equal(t, plan.Upsert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES (?, ?, ?) ON CONFLICT ("group_id", "name") DO UPDATE SET "created_at" = EXCLUDED."created_at"`) + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = ? WHERE "group_id" = ? AND "name" = ?`) assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}}) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) @@ -392,6 +431,9 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) { assert.Equal(t, plan.Insert.Query, `INSERT INTO "autogenerated_records" ("name") VALUES ($1) RETURNING "id", "created_at"`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}, {2}}) + assert.Equal(t, plan.Upsert.Query, "") + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) assert.Equal(t, plan.Update.Query, `UPDATE "autogenerated_records" SET "name" = $1, "created_at" = $2 WHERE "id" = $3`) assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) assert.DeepEqual(t, plan.Update.ScanIndexes, nil) |
