aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--benchmark/benchmark_test.go3
-rw-r--r--query.go21
2 files changed, 13 insertions, 11 deletions
diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go
index f026ca6..e2fc05e 100644
--- a/benchmark/benchmark_test.go
+++ b/benchmark/benchmark_test.go
@@ -222,7 +222,8 @@ func BenchmarkInsertAndDeleteOne(b *testing.B) {
insertAndDeleteWithOblast := func(b *testing.B) {
record := OblastEntry{Message: "hello"}
- must.Succeed(b, store.Insert(db, &record))
+ records := must.Return(store.Insert(db, record))(b)
+ record = records[0]
if record.ID == 0 {
b.Errorf("ID was not filled!")
}
diff --git a/query.go b/query.go
index 8b0d2cd..85c2eb5 100644
--- a/query.go
+++ b/query.go
@@ -13,15 +13,16 @@ import (
//
// Fields that are declared with the "auto" tag will not be written into the DB,
// and instead their value (as auto-generated by the DB on insert) will be placed in the record.
+// On success, returns the original set of records, updated thusly.
//
// Returns an error if [NewStore] was called without the [TableNameIs] option, which is required to generate a query for this method.
-func (s Store[R]) Insert(db Handle, records ...*R) (returnedError error) {
+func (s Store[R]) Insert(db Handle, records ...R) (returnedRecords []R, returnedError error) {
// NOTE: This function body should be as short as possible to reduce the binary size after monomorphization.
// Any expression that does not depend on type R should be factored out into a reusable function.
// TODO: minimize
if s.plan.Insert.Query == "" {
- return errors.New("cannot execute Insert() because query could not be autogenerated")
+ return nil, errors.New("cannot execute Insert() because query could not be autogenerated")
}
var (
@@ -36,14 +37,14 @@ func (s Store[R]) Insert(db Handle, records ...*R) (returnedError error) {
stmt, err := db.Prepare(s.plan.Insert.Query)
if err != nil {
- return fmt.Errorf("during Prepare(): %w", err)
+ return nil, fmt.Errorf("during Prepare(): %w", err)
}
defer func() {
returnedError = mergeCloseError("Stmt", returnedError, stmt.Close())
}()
- for idx, r := range records {
- v := reflect.ValueOf(r).Elem()
+ for idx := range records {
+ v := reflect.ValueOf(&records[idx]).Elem()
for idx, index := range argumentIndexes {
argumentSlots[idx] = v.FieldByIndex(index).Addr().Interface()
}
@@ -51,17 +52,17 @@ func (s Store[R]) Insert(db Handle, records ...*R) (returnedError error) {
if s.dialect.UsesLastInsertID() {
result, err := stmt.Exec(argumentSlots...)
if err != nil {
- return fmt.Errorf("during Exec() for record with idx = %d: %w", idx, err)
+ return nil, fmt.Errorf("during Exec() for record with idx = %d: %w", idx, err)
}
id, err := result.LastInsertId()
if err != nil {
- return fmt.Errorf("during LastInsertId() for record with idx = %d: %w", idx, err)
+ return nil, fmt.Errorf("during LastInsertId() for record with idx = %d: %w", idx, err)
}
if s.plan.FillIDWithSetInt {
v.FieldByIndex(scanIndexes[0]).SetInt(id)
} else if s.plan.FillIDWithSetUint {
if id < 0 {
- return fmt.Errorf("LastInsertId() = %d for record with idx = %d cannot be converted to uint", id, idx)
+ return nil, fmt.Errorf("LastInsertId() = %d for record with idx = %d cannot be converted to uint", id, idx)
}
v.FieldByIndex(scanIndexes[0]).SetUint(uint64(id))
}
@@ -71,12 +72,12 @@ func (s Store[R]) Insert(db Handle, records ...*R) (returnedError error) {
}
err := stmt.QueryRow(argumentSlots...).Scan(scanSlots...)
if err != nil {
- return fmt.Errorf("during QueryRow() for record with idx = %d: %w", idx, err)
+ return nil, fmt.Errorf("during QueryRow() for record with idx = %d: %w", idx, err)
}
}
}
- return nil
+ return records, nil
}
// TODO: update