aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGELOG.md3
-rw-r--r--TODO.md6
-rw-r--r--query.go92
-rw-r--r--query_test.go98
4 files changed, 188 insertions, 11 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8d7c623..948bb2e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ Changes:
- `Store.Insert()` now takes its arguments by-pointer. This is probably slightly less efficient,
but significantly safer because autogenerated field values cannot be disregarded by accident.
+- Add `Store.Update()`.
- Removed support for SQL dialects that rely on LastInsertId() for ID columns.
Using a RETURNING clause to collect autogenerated field values is objectively better in every way,
and has been supported by both MariaDB and SQLite for at least six years.
@@ -18,7 +19,7 @@ Changes:
Changes:
-- Add func StructTagKeyIs.
+- Add `func StructTagKeyIs()`.
# v0.1.0 (2026-04-18)
diff --git a/TODO.md b/TODO.md
deleted file mode 100644
index f427337..0000000
--- a/TODO.md
+++ /dev/null
@@ -1,6 +0,0 @@
-<!--
-SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net>
-SPDX-License-Identifier: Apache-2.0
--->
-
-- TODO: consider adding an upsert, e.g. `func (Store[R]) InsertOrUpdate(db Handle, records ...*R) error`, that chooses based on whether any auto fields is non-zero
diff --git a/query.go b/query.go
index 2403f7e..f5f6fb7 100644
--- a/query.go
+++ b/query.go
@@ -85,6 +85,12 @@ func (s Store[R]) Insert(db Handle, records ...*R) error {
if err != nil {
return err
}
+ return s.insertUsing(stmt, db, records)
+}
+
+func (s Store[R]) insertUsing(stmt preparedStatement, db Handle, records []*R) error {
+ // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization.
+ // Any expression that does not depend on type R should be factored out into a reusable function.
var (
argumentIndexes = s.plan.Insert.ArgumentIndexes
@@ -93,8 +99,8 @@ func (s Store[R]) Insert(db Handle, records ...*R) error {
scanSlots = make([]any, len(scanIndexes))
)
- for idx := range records {
- v := reflect.ValueOf(records[idx]).Elem()
+ for idx, r := range records {
+ v := reflect.ValueOf(r).Elem()
err := insertRecord(v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots)
if err != nil {
return newIOError(err, "Stmt.Close", stmt.Close())
@@ -211,3 +217,85 @@ func deleteRecord(v reflect.Value, recordIndex int, stmt preparedStatement, argu
}
return nil
}
+
+// Upsert executes either an SQL INSERT or UPDATE statement for each of the provided records,
+// based on whether the record already exists in the DB or not.
+//
+// - For record types that have fields declared with the "auto" tag, INSERT is chosen iff those fields hold zero values.
+// Returns an error if only some of the respective fields hold zero values while others don't.
+// Returns an error if [NewStore] was called without the [TableNameIs] or [PrimaryKeyIs] options, which are both required to generate the respective queries for this method.
+// - For record types that do not have fields declared with the "auto" tag, an INSERT ... ON CONFLICT statement is used.
+// Returns an error if [NewStore] was called without the [TableNameIs] option, which is required to generate a query for this method.
+func (s Store[R]) Upsert(db Handle, records ...*R) error {
+ // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization.
+ // Any expression that does not depend on type R should be factored out into a reusable function.
+
+ if len(s.plan.AutoColumnNames) == 0 {
+ stmt, err := prepare(db, s.plan.Upsert.Query, "Upsert", len(records))
+ if err != nil {
+ return err
+ }
+ return s.insertUsing(stmt, db, records)
+ }
+
+ // TODO: respect PrepareThreshold (or not? may be too much bookkeeping overhead for not a whole lot of benefit)
+ insertStmt, err := prepare(db, s.plan.Insert.Query, "Insert", 0)
+ if err != nil {
+ return err
+ }
+ updateStmt, err := prepare(db, s.plan.Update.Query, "Update", 0)
+ if err != nil {
+ return err
+ }
+
+ var (
+ insertArgumentIndexes = s.plan.Insert.ArgumentIndexes
+ insertArgumentSlots = make([]any, len(insertArgumentIndexes))
+ insertScanIndexes = s.plan.Insert.ScanIndexes
+ insertScanSlots = make([]any, len(insertScanIndexes))
+ updateArgumentIndexes = s.plan.Update.ArgumentIndexes
+ updateArgumentSlots = make([]any, len(updateArgumentIndexes))
+ )
+
+ for idx, r := range records {
+ v := reflect.ValueOf(r).Elem()
+ isInsert, err := upsertDecideStrategy(v, idx, insertScanIndexes)
+ if err != nil {
+ return err
+ }
+
+ if isInsert {
+ err = insertRecord(v, idx, insertStmt, insertArgumentIndexes, insertArgumentSlots, insertScanIndexes, insertScanSlots)
+ } else {
+ var rowsAffected int64
+ rowsAffected, err = updateRecord(v, idx, updateStmt, updateArgumentIndexes, updateArgumentSlots)
+ if err == nil && rowsAffected == 0 {
+ err = MissingRecordError[R]{*r, s.plan}
+ }
+ }
+ if err != nil {
+ err = newIOError(err, "InsertStmt.Close", insertStmt.Close())
+ err = newIOError(err, "UpdateStmt.Close", updateStmt.Close())
+ return err
+ }
+ }
+
+ err = newIOError(err, "InsertStmt.Close", insertStmt.Close())
+ err = newIOError(err, "UpdateStmt.Close", updateStmt.Close())
+ return err
+}
+
+func upsertDecideStrategy(v reflect.Value, recordIndex int, scanIndexes [][]int) (isInsert bool, err error) {
+ var isUpdate bool
+ for _, index := range scanIndexes {
+ if v.FieldByIndex(index).IsZero() {
+ isInsert = true
+ } else {
+ isUpdate = true
+ }
+ }
+ if isInsert && isUpdate {
+ return false, fmt.Errorf(`cannot decide whether to INSERT or UPDATE record with idx = %d: some "auto" columns are zero, others are not`, recordIndex)
+ }
+ return isInsert, nil
+}
diff --git a/query_test.go b/query_test.go
index 2809f6e..6f73642 100644
--- a/query_test.go
+++ b/query_test.go
@@ -7,6 +7,7 @@ import (
"database/sql"
"strconv"
"testing"
+ "time"
"go.xyrillian.de/oblast"
"go.xyrillian.de/oblast/internal/testhelpers/assert"
@@ -105,6 +106,51 @@ func TestDeleteBasic(t *testing.T) {
}
}
+func TestUpsertBasicWithAutoColumn(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID int64 `db:"id,auto"`
+ Name string `db:"name"`
+ }
+ store := oblast.MustNewStore[basicRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`).
+ ExpectQueryWithArgs("first needs insert").
+ AndReturnColumns("id").
+ WithRow(int64(1))
+ md.ForQuery(`UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`).
+ ExpectExecWithArgs("second needs update", 2).
+ AndReturnRowsAffected(1)
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`).
+ ExpectQueryWithArgs("third needs insert").
+ AndReturnColumns("id").
+ WithRow(int64(3))
+ md.ForQuery(`UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`).
+ ExpectExecWithArgs("fourth needs update", 4).
+ AndReturnRowsAffected(1)
+
+ records := []*basicRecord{
+ {Name: "first needs insert"},
+ {ID: 2, Name: "second needs update"},
+ {Name: "third needs insert"},
+ {ID: 4, Name: "fourth needs update"},
+ }
+ must.Succeed(t, store.Upsert(db, records...))
+
+ assert.SliceDeepEqual(t, records,
+ &basicRecord{ID: 1, Name: "first needs insert"},
+ &basicRecord{ID: 2, Name: "second needs update"},
+ &basicRecord{ID: 3, Name: "third needs insert"},
+ &basicRecord{ID: 4, Name: "fourth needs update"},
+ )
+}
+
func TestWriteQueriesNotPossible(t *testing.T) {
md := mock.NewDriver()
db := sql.OpenDB(md)
@@ -122,6 +168,9 @@ func TestWriteQueriesNotPossible(t *testing.T) {
err := store.Insert(db, &r)
assert.ErrEqual(t, err, "cannot execute Insert() because query could not be autogenerated")
+ err = store.Upsert(db, &r)
+ assert.ErrEqual(t, err, "cannot execute Insert() because query could not be autogenerated")
+
r.ID = 42
err = store.Update(db, r)
assert.ErrEqual(t, err, "cannot execute Update() because query could not be autogenerated")
@@ -178,7 +227,7 @@ func TestWriteQueriesFailDuringPrepare(t *testing.T) {
}
}
-func TestUpdateFailsOnMissingRecord(t *testing.T) {
+func TestUpdateOrUpsertFailsOnMissingRecord(t *testing.T) {
md := mock.NewDriver()
db := sql.OpenDB(md)
@@ -192,6 +241,7 @@ func TestUpdateFailsOnMissingRecord(t *testing.T) {
oblast.PrimaryKeyIs("id"),
)
+ // test Update()
md.ForQuery(`UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`).
ExpectExecWithArgs("changed", 42).
AndReturnRowsAffected(0)
@@ -199,6 +249,16 @@ func TestUpdateFailsOnMissingRecord(t *testing.T) {
assert.ErrEqual(t, err, "could not UPDATE record that does not exist in the database: id = 42")
_, hasCorrectType := err.(oblast.MissingRecordError[basicRecord]) //nolint:errorlint // we explicitly do not want a wrapped error
assert.Equal(t, hasCorrectType, true)
+
+ // test Upsert() -> this will not try inserting because the strategy
+ // is chosen based on the fill state of the "auto" field
+ md.ForQuery(`UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`).
+ ExpectExecWithArgs("changed", 42).
+ AndReturnRowsAffected(0)
+ err = store.Upsert(db, &basicRecord{ID: 42, Name: "changed"})
+ assert.ErrEqual(t, err, "could not UPDATE record that does not exist in the database: id = 42")
+ _, hasCorrectType = err.(oblast.MissingRecordError[basicRecord]) //nolint:errorlint // we explicitly do not want a wrapped error
+ assert.Equal(t, hasCorrectType, true)
}
func TestInsertFailsOnFilledAutoField(t *testing.T) {
@@ -223,7 +283,7 @@ func TestInsertFailsOnFilledAutoField(t *testing.T) {
assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0 that already has non-zero values in its "auto" columns`)
}
-func TestInsertWithNoAutoColumns(t *testing.T) {
+func TestInsertAndUpsertWithNoAutoColumns(t *testing.T) {
md := mock.NewDriver()
db := sql.OpenDB(md)
@@ -237,8 +297,42 @@ func TestInsertWithNoAutoColumns(t *testing.T) {
oblast.PrimaryKeyIs("foo_id", "bar_id"),
)
+ // test Insert()
md.ForQuery(`INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`).
ExpectExecWithArgs(23, 42).
AndReturnRowsAffected(1)
must.Succeed(t, store.Insert(db, &relation{23, 42}))
+
+ // test Upsert()
+ md.ForQuery(`INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?) ON CONFLICT ("foo_id", "bar_id") DO NOTHING`).
+ ExpectExecWithArgs(1, 2).
+ AndReturnRowsAffected(1)
+ md.ForQuery(`INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?) ON CONFLICT ("foo_id", "bar_id") DO NOTHING`).
+ ExpectExecWithArgs(3, 4).
+ AndReturnRowsAffected(1)
+ must.Succeed(t, store.Upsert(db, &relation{1, 2}, &relation{3, 4}))
+}
+
+func TestUpsertFailsOnMixedAutoFieldState(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type complexRecord struct {
+ ID int64 `db:"id,auto"`
+ Name string `db:"name"`
+ CreatedAt time.Time `db:"created_at,auto"`
+ }
+ store := oblast.MustNewStore[complexRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("complex_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ brokenRecord := complexRecord{
+ ID: 42, // this looks like we need to UPDATE
+ Name: "foo",
+ CreatedAt: time.Time{}, // this looks like we need to INSERT
+ }
+ err := store.Upsert(db, &brokenRecord)
+ assert.ErrEqual(t, err, `cannot decide whether to INSERT or UPDATE record with idx = 0: some "auto" columns are zero, others are not`)
}