aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--dialect.go45
-rw-r--r--plan.go16
-rw-r--r--plan_test.go42
3 files changed, 99 insertions, 4 deletions
diff --git a/dialect.go b/dialect.go
index d74725e..a098ca2 100644
--- a/dialect.go
+++ b/dialect.go
@@ -4,6 +4,7 @@
package oblast
import (
+ "fmt"
"strconv"
"strings"
)
@@ -37,6 +38,12 @@ type Dialect interface {
// If UsesLastInsertID is true, this is usually not needed and the empty
// string can be returned.
InsertSuffixForAutoColumns(columns []string) string
+
+ // UpsertClause generates an "ON CONFLICT" or similar clause
+ // that can be appended to an INSERT query to make it fall back to
+ // behave like UPDATE if a record with the same primary key already exists.
+ // This is only used for record types that have a primary key.
+ UpsertClause(pkColumns, otherColumns []string) string
}
// MysqlDialect is the dialect of MySQL and MariaDB databases.
@@ -51,6 +58,20 @@ func (mysqlDialect) QuoteIdentifier(name string) string { return
func (mysqlDialect) UsesLastInsertID() bool { return true }
func (mysqlDialect) InsertSuffixForAutoColumns(columns []string) string { return "" }
+func (d mysqlDialect) UpsertClause(pkColumns, otherColumns []string) string {
+ clauses := make([]string, max(1, len(otherColumns)))
+ if len(otherColumns) == 0 {
+ // we need at least one UPDATE clause; if there are no non-PK columns,
+ // we can just use one of the PK columns, updating those is a safe no-op
+ clauses[0] = fmt.Sprintf(`%[1]s = VALUES(%[1]s)`, d.QuoteIdentifier(pkColumns[0]))
+ } else {
+ for idx, name := range otherColumns {
+ clauses[idx] = fmt.Sprintf(`%[1]s = VALUES(%[1]s)`, d.QuoteIdentifier(name))
+ }
+ }
+ return ` ON DUPLICATE KEY UPDATE ` + strings.Join(clauses, ", ")
+}
+
// PostgresDialect is the dialect of PostgreSQL databases.
func PostgresDialect() Dialect {
return postgresDialect{}
@@ -62,14 +83,31 @@ func (postgresDialect) Placeholder(i int) string { return "$" + strcon
func (postgresDialect) QuoteIdentifier(name string) string { return `"` + name + `"` }
func (postgresDialect) UsesLastInsertID() bool { return false }
-func (p postgresDialect) InsertSuffixForAutoColumns(columns []string) string {
+func (d postgresDialect) InsertSuffixForAutoColumns(columns []string) string {
quotedColumns := make([]string, len(columns))
for idx, name := range columns {
- quotedColumns[idx] = p.QuoteIdentifier(name)
+ quotedColumns[idx] = d.QuoteIdentifier(name)
}
return ` RETURNING ` + strings.Join(quotedColumns, ", ")
}
+func (d postgresDialect) UpsertClause(pkColumns, otherColumns []string) string {
+ quotedPkColumns := make([]string, len(pkColumns))
+ for idx, name := range pkColumns {
+ quotedPkColumns[idx] = d.QuoteIdentifier(name)
+ }
+ clauses := make([]string, len(otherColumns))
+ for idx, name := range otherColumns {
+ clauses[idx] = fmt.Sprintf(`%[1]s = EXCLUDED.%[1]s`, d.QuoteIdentifier(name))
+ }
+ if len(otherColumns) == 0 {
+ return fmt.Sprintf(` ON CONFLICT (%s) DO NOTHING`, strings.Join(quotedPkColumns, ", "))
+ } else {
+ return fmt.Sprintf(` ON CONFLICT (%s) DO UPDATE SET %s`,
+ strings.Join(quotedPkColumns, ", "), strings.Join(clauses, ", "))
+ }
+}
+
// SqliteDialect is the dialect of SQLite databases.
func SqliteDialect() Dialect {
return sqliteDialect{}
@@ -81,3 +119,6 @@ func (sqliteDialect) Placeholder(_ int) string { retur
func (sqliteDialect) QuoteIdentifier(name string) string { return `"` + name + `"` }
func (sqliteDialect) UsesLastInsertID() bool { return true }
func (sqliteDialect) InsertSuffixForAutoColumns(columns []string) string { return "" }
+func (sqliteDialect) UpsertClause(pkColumns, otherColumns []string) string {
+ return postgresDialect{}.UpsertClause(pkColumns, otherColumns)
+}
diff --git a/plan.go b/plan.go
index 9c4da54..9e9f44c 100644
--- a/plan.go
+++ b/plan.go
@@ -33,6 +33,7 @@ type plan struct {
// Planned queries.
Select plannedQuery // only `SELECT ... FROM ... WHERE `; user supplies the rest during Select{,One}Where()
Insert plannedQuery
+ Upsert plannedQuery
Update plannedQuery
Delete plannedQuery
}
@@ -198,7 +199,8 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) {
// prepare query strings
p.Select = p.buildSelectQueryIfPossible(dialect)
- p.Insert = p.buildInsertQueryIfPossible(dialect)
+ p.Insert = p.buildInsertQueryIfPossible(dialect, false)
+ p.Upsert = p.buildInsertQueryIfPossible(dialect, true)
p.Update = p.buildUpdateQueryIfPossible(dialect)
p.Delete = p.buildDeleteQueryIfPossible(dialect)
@@ -247,7 +249,7 @@ func (p plan) buildSelectQueryIfPossible(dialect Dialect) plannedQuery {
return plannedQuery{query, nil, scanIndexes}
}
-func (p plan) buildInsertQueryIfPossible(dialect Dialect) plannedQuery {
+func (p plan) buildInsertQueryIfPossible(dialect Dialect, isUpsert bool) plannedQuery {
if p.TableName == "" || len(p.AllColumnNames) == 0 {
return plannedQuery{Query: ""}
}
@@ -256,6 +258,13 @@ func (p plan) buildInsertQueryIfPossible(dialect Dialect) plannedQuery {
return plannedQuery{Query: ""}
}
+ // UPSERT queries specifically are only generated if we have non-auto primary keys:
+ // - cannot hit a key conflict if there are no keys
+ // - cannot hit a key conflict on insert if all keys are autogenerated (and thus we never supply them during INSERT)
+ if isUpsert && !slices.ContainsFunc(p.PrimaryKeyColumnNames, func(n string) bool { return !slices.Contains(p.AutoColumnNames, n) }) {
+ return plannedQuery{Query: ""}
+ }
+
var (
argumentIndexes = make([][]int, len(nonAutoColumnNames))
scanIndexes [][]int
@@ -282,6 +291,9 @@ func (p plan) buildInsertQueryIfPossible(dialect Dialect) plannedQuery {
strings.Join(quotedColumnNames, ", "),
strings.Join(quotedPlaceholders, ", "),
)
+ if isUpsert {
+ query += dialect.UpsertClause(p.PrimaryKeyColumnNames, p.getNonPrimaryKeyColumnNames())
+ }
if len(p.AutoColumnNames) > 0 {
query += dialect.InsertSuffixForAutoColumns(p.AutoColumnNames)
}
diff --git a/plan_test.go b/plan_test.go
index 772c14a..08c7252 100644
--- a/plan_test.go
+++ b/plan_test.go
@@ -80,6 +80,9 @@ func TestQueryConstructionBasic(t *testing.T) {
assert.Equal(t, plan.Insert.Query, "INSERT INTO `basic_records` (`Description`, `CreatedAt`) VALUES (?, ?)")
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}})
+ assert.Equal(t, plan.Upsert.Query, "")
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, "UPDATE `basic_records` SET `Description` = ?, `CreatedAt` = ? WHERE `ID` = ?")
assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}})
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -99,6 +102,9 @@ func TestQueryConstructionBasic(t *testing.T) {
assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES ($1, $2) RETURNING "ID"`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}})
+ assert.Equal(t, plan.Upsert.Query, "")
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = $1, "CreatedAt" = $2 WHERE "ID" = $3`)
assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}})
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -118,6 +124,9 @@ func TestQueryConstructionBasic(t *testing.T) {
assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}})
+ assert.Equal(t, plan.Upsert.Query, "")
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = ?, "CreatedAt" = ? WHERE "ID" = ?`)
assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}})
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -148,6 +157,9 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) {
assert.Equal(t, plan.Insert.Query, "INSERT INTO `foo_bar_relations` (`foo_id`, `bar_id`) VALUES (?, ?)")
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
+ assert.Equal(t, plan.Upsert.Query, "INSERT INTO `foo_bar_relations` (`foo_id`, `bar_id`) VALUES (?, ?) ON DUPLICATE KEY UPDATE `foo_id` = VALUES(`foo_id`)")
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, [][]int{{0}, {1}})
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, "")
assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -167,6 +179,9 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) {
assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
+ assert.Equal(t, plan.Upsert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2) ON CONFLICT ("foo_id", "bar_id") DO NOTHING`)
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, [][]int{{0}, {1}})
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, "")
assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -186,6 +201,9 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) {
assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
+ assert.Equal(t, plan.Upsert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?) ON CONFLICT ("foo_id", "bar_id") DO NOTHING`)
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, [][]int{{0}, {1}})
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, "")
assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -215,6 +233,9 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) {
assert.Equal(t, plan.Insert.Query, "INSERT INTO `foo_bar_relations` (`foo_id`, `bar_id`) VALUES (?, ?)")
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
+ assert.Equal(t, plan.Upsert.Query, "")
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, "")
assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -234,6 +255,9 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) {
assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
+ assert.Equal(t, plan.Upsert.Query, "")
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, "")
assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -253,6 +277,9 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) {
assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
+ assert.Equal(t, plan.Upsert.Query, "")
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, "")
assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -282,6 +309,9 @@ func TestQueryConstructionImpossble(t *testing.T) {
assert.Equal(t, plan.Insert.Query, "")
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
+ assert.Equal(t, plan.Upsert.Query, "")
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, "")
assert.DeepEqual(t, plan.Update.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -318,6 +348,9 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) {
assert.Equal(t, plan.Insert.Query, "INSERT INTO `complex_records` (`group_id`, `name`, `created_at`) VALUES (?, ?, ?)")
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
+ assert.Equal(t, plan.Upsert.Query, "INSERT INTO `complex_records` (`group_id`, `name`, `created_at`) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE `created_at` = VALUES(`created_at`)")
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, [][]int{{0}, {1}, {2}})
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, "UPDATE `complex_records` SET `created_at` = ? WHERE `group_id` = ? AND `name` = ?")
assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}})
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -337,6 +370,9 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) {
assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES ($1, $2, $3)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
+ assert.Equal(t, plan.Upsert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES ($1, $2, $3) ON CONFLICT ("group_id", "name") DO UPDATE SET "created_at" = EXCLUDED."created_at"`)
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, [][]int{{0}, {1}, {2}})
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = $1 WHERE "group_id" = $2 AND "name" = $3`)
assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}})
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -356,6 +392,9 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) {
assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES (?, ?, ?)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, nil)
+ assert.Equal(t, plan.Upsert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES (?, ?, ?) ON CONFLICT ("group_id", "name") DO UPDATE SET "created_at" = EXCLUDED."created_at"`)
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, [][]int{{0}, {1}, {2}})
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = ? WHERE "group_id" = ? AND "name" = ?`)
assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{2}, {0}, {1}})
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)
@@ -392,6 +431,9 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) {
assert.Equal(t, plan.Insert.Query, `INSERT INTO "autogenerated_records" ("name") VALUES ($1) RETURNING "id", "created_at"`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}, {2}})
+ assert.Equal(t, plan.Upsert.Query, "")
+ assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil)
+ assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil)
assert.Equal(t, plan.Update.Query, `UPDATE "autogenerated_records" SET "name" = $1, "created_at" = $2 WHERE "id" = $3`)
assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}})
assert.DeepEqual(t, plan.Update.ScanIndexes, nil)