From d964c2be59a73e6b21ce1a8031fe913588bddf66 Mon Sep 17 00:00:00 2001 From: Stefan Majewsky Date: Thu, 16 Apr 2026 21:18:04 +0200 Subject: add Store.Update() --- benchmark/benchmark_test.go | 115 +++++++++++++++++++++++++++++++++++++++++--- oblast.go | 19 ++++++++ query.go | 54 ++++++++++++++++++++- 3 files changed, 181 insertions(+), 7 deletions(-) diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index c97b887..83832ba 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -21,6 +21,7 @@ var ( totalRecordCountForSelect = 10000 batchSizesForSelect = []int{1, 10, 100, 1000} batchSizesForInsertDelete = []int{1, 2, 4, 8, 16, 100} + batchSizesForUpdate = []int{1, 2, 4, 8, 16, 100} ) func makeTestDB(t testing.TB, recordCount int) *sql.DB { @@ -103,7 +104,7 @@ func BenchmarkSelectMany(b *testing.B) { assert.Equal(b, count, batchSize) } - // run once to prewarm caches + // run once to prewarm caches (if any) selectWithOblast(b) selectWithGorp(b) if b.Failed() { @@ -179,7 +180,7 @@ func BenchmarkSelectOne(b *testing.B) { assert.Equal(b, id, int64(recordID)) } - // run once to prewarm caches + // run once to prewarm caches (if any) selectWithOblast(b) selectWithGorp(b) if b.Failed() { @@ -212,7 +213,6 @@ func BenchmarkSelectOne(b *testing.B) { func BenchmarkInsertAndDelete(b *testing.B) { db := makeTestDB(b, 0) - // prepare the functions that will be benched store, err := oblast.NewStore[OblastEntry]( oblast.SqliteDialect(), oblast.TableNameIs("entries"), @@ -282,7 +282,7 @@ func BenchmarkInsertAndDelete(b *testing.B) { } } - // run once to prewarm caches + // run once to prewarm caches (if any) insertAndDeleteWithOblast(b) insertAndDeleteWithGorp(b) @@ -296,12 +296,12 @@ func BenchmarkInsertAndDelete(b *testing.B) { insertAndDeleteWithOblast(b) } }) - b.Run("just straight SQLite", func(b *testing.B) { + b.Run("just SQLite (straight)", func(b *testing.B) { for b.Loop() { insertAndDeleteWithStraightSqlite(b) } }) - b.Run("just prepared SQLite", func(b *testing.B) { + b.Run("just SQLite (prepared)", func(b *testing.B) { for b.Loop() { insertAndDeleteWithPreparedSqlite(b) } @@ -309,3 +309,106 @@ func BenchmarkInsertAndDelete(b *testing.B) { }) } } + +func BenchmarkUpdate(b *testing.B) { + db := makeTestDB(b, 0) + + store, err := oblast.NewStore[OblastEntry]( + oblast.SqliteDialect(), + oblast.TableNameIs("entries"), + oblast.PrimaryKeyIs("id"), + ) + if err != nil { + b.Fatal(err) + } + gdb := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}} + gdb.AddTableWithName(GorpEntry{}, "entries").SetKeys(true, "id") + + // test with different amounts of records + for _, batchSize := range batchSizesForUpdate { + b.Run("N="+strconv.Itoa(batchSize), func(b *testing.B) { + // prepare a bunch of records that we can update, in a reproducible way + _ = must.Return(db.Exec(`DELETE FROM entries`)) + recordsForOblast := make([]OblastEntry, batchSize) + for idx := range recordsForOblast { + recordsForOblast[idx] = OblastEntry{Message: "hello"} + } + recordsForOblast = must.Return(store.Insert(db, recordsForOblast...))(b) + recordsForGorp := make([]any, batchSize) + for idx, r := range recordsForOblast { + recordsForGorp[idx] = new(GorpEntry(r)) + } + + // prepare the functions that will be benched + updateWithOblast := func(b *testing.B, message string) { + for idx := range recordsForOblast { + recordsForOblast[idx].Message = message + } + must.Succeed(b, store.Update(db, recordsForOblast...)) + } + updateWithGorp := func(b *testing.B, message string) { + for _, r := range recordsForGorp { + r.(*GorpEntry).Message = message + } + _ = must.Return(gdb.Update(recordsForGorp...))(b) + } + updateWithStraightSqlite := func(b *testing.B, message string) { + for _, r := range recordsForOblast { + _ = must.Return(db.Exec(`UPDATE entries SET message = ? WHERE id = ?`, message, r.ID))(b) + } + } + updateWithPreparedSqlite := func(b *testing.B, message string) { + stmt := must.Return(db.Prepare(`UPDATE entries SET message = ? WHERE id = ?`))(b) + for _, r := range recordsForOblast { + _ = must.Return(stmt.Exec(message, r.ID))(b) + } + } + checkRecordsUpdated := func(b *testing.B, message string) { + var count int64 + must.Succeed(b, db.QueryRow(`SELECT COUNT(*) FROM entries WHERE message = ?`, message).Scan(&count)) + assert.Equal(b, count, int64(batchSize)) + } + + // run once to prewarm caches (if any) + updateWithGorp(b, "warming up") + updateWithOblast(b, "warming up") + + b.Run("via Gorp", func(b *testing.B) { + idx := 0 + for b.Loop() { + idx++ + message := fmt.Sprintf("round %d", idx) + updateWithGorp(b, message) + checkRecordsUpdated(b, message) + } + }) + b.Run("via Oblast", func(b *testing.B) { + idx := 0 + for b.Loop() { + idx++ + message := fmt.Sprintf("round %d", idx) + updateWithOblast(b, message) + checkRecordsUpdated(b, message) + } + }) + b.Run("just SQLite (straight)", func(b *testing.B) { + idx := 0 + for b.Loop() { + idx++ + message := fmt.Sprintf("round %d", idx) + updateWithStraightSqlite(b, message) + checkRecordsUpdated(b, message) + } + }) + b.Run("just SQLite (prepared)", func(b *testing.B) { + idx := 0 + for b.Loop() { + idx++ + message := fmt.Sprintf("round %d", idx) + updateWithPreparedSqlite(b, message) + checkRecordsUpdated(b, message) + } + }) + }) + } +} diff --git a/oblast.go b/oblast.go index 52c0cfd..42305e6 100644 --- a/oblast.go +++ b/oblast.go @@ -100,6 +100,7 @@ import ( "database/sql/driver" "fmt" "reflect" + "strings" ) var ( @@ -170,3 +171,21 @@ func MustNewStore[R any](dialect Dialect, opts ...PlanOption) Store[R] { } return store } + +// MissingRecordError is returned by [Store.Update] if one of the rows to be updated does not exist in the DB. +type MissingRecordError[R any] struct { + // The record that was provided to [Store.Update], + // but for which no row with the same primary key values could be located. + Record R + plan plan +} + +// Error implements the builtin/error interface. +func (e MissingRecordError[R]) Error() string { + keyDescs := make([]string, len(e.plan.PrimaryKeyColumnNames)) + v := reflect.ValueOf(e.Record) + for idx, columnName := range e.plan.PrimaryKeyColumnNames { + keyDescs[idx] = fmt.Sprintf("%s = %#v", columnName, v.FieldByIndex(e.plan.IndexByColumnName[columnName])) + } + return fmt.Sprintf("could not UPDATE record that does not exist in the database: %s", strings.Join(keyDescs, ", ")) +} diff --git a/query.go b/query.go index 88c2987..41d61c7 100644 --- a/query.go +++ b/query.go @@ -105,7 +105,59 @@ func (s Store[R]) Insert(db Handle, records ...R) (returnedRecords []R, returned return records, nil } -// TODO: Store.Update +// Update executes an SQL UPDATE statement for each of the provided records, updating all non-primary-key columns with the values in the records. +// Returns [MissingRecordError] if any of the records does not exist in the database, that is, if for any of the records, the database contains no row with the same primary key values. +// +// Returns an error if [NewStore] was called without the [TableNameIs] or [PrimaryKeyIs] options, which are both required to generate a query for this method. +func (s Store[R]) Update(db Handle, records ...R) (returnedError error) { + if s.plan.Update.Query == "" { + return errors.New("cannot execute Update() because query could not be autogenerated") + } + + var ( + argumentIndexes = s.plan.Update.ArgumentIndexes + argumentSlots = make([]any, len(argumentIndexes)) + ) + + var stmt *sql.Stmt + if len(records) >= PrepareThreshold { + var err error + stmt, err = db.Prepare(s.plan.Update.Query) + if err != nil { + return fmt.Errorf("during Prepare(): %w", err) + } + defer func() { + returnedError = mergeCloseError("Stmt", returnedError, stmt.Close()) + }() + } + + for idx, r := range records { + v := reflect.ValueOf(&r).Elem() + for idx, index := range argumentIndexes { + argumentSlots[idx] = v.FieldByIndex(index).Addr().Interface() + } + var ( + result sql.Result + err error + ) + if stmt == nil { + result, err = db.Exec(s.plan.Update.Query, argumentSlots...) + } else { + result, err = stmt.Exec(argumentSlots...) + } + if err != nil { + return fmt.Errorf("during Exec() for record with idx = %d: %w", idx, err) + } + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("during RowsAffected() for record with idx = %d: %w", idx, err) + } + if rowsAffected == 0 { + return MissingRecordError[R]{r, s.plan} + } + } + return nil +} // Delete executes an SQL DELETE statement for each of the provided records, using their primary keys to locate the respective table rows. // -- cgit v1.2.3