aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Majewsky <majewsky@gmx.net>2026-05-22 14:01:24 +0200
committerStefan Majewsky <majewsky@gmx.net>2026-05-22 14:01:24 +0200
commit764eaf643e323b92a616fc8e6a193855bb43d905 (patch)
tree935827e791480719a1cf63f806c7e21006a0fb19
parent091f9b68a70d617a38ddf7a662aaf351724be746 (diff)
downloadgo-oblast-764eaf643e323b92a616fc8e6a193855bb43d905.tar.gz
bring back support for LastInsertId-based INSERT
As the remaining TODO noted, this really is much more memory-efficient than QueryRow when we can use it, since it does not allocate an *sql.Rows instance inside the *sql.Row instance where we call Scan().
-rw-r--r--CHANGELOG.md8
-rw-r--r--dialect.go23
-rw-r--r--plan.go39
-rw-r--r--plan_test.go21
-rw-r--r--query.go37
-rw-r--r--query_test.go143
6 files changed, 229 insertions, 42 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a311b01..41dcdfd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,14 @@ SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net>
SPDX-License-Identifier: Apache-2.0
-->
+# v0.10.0 (TBD)
+
+Changes:
+
+- Dialects that support it (i.e. MariaDB and SQLite) will once again prefer collecting autogenerated IDs through `LastInsertId()`.
+ RETURNING clauses will only be used when multiple fields have the `db:",auto"` tag.
+ This improves memory consumption for INSERT and UPSERT queries on those dialects.
+
# v0.9.0 (2026-05-18)
API changes:
diff --git a/dialect.go b/dialect.go
index 3c49f58..d057c8d 100644
--- a/dialect.go
+++ b/dialect.go
@@ -4,11 +4,17 @@
package oblast
import (
+ "database/sql"
"fmt"
"strconv"
"strings"
)
+var (
+ // force imports to make docstring links work
+ _ = sql.Result(nil)
+)
+
// Dialect accounts for differences between different SQL dialects
// that are relevant to query generation within Oblast.
//
@@ -27,6 +33,11 @@ type Dialect interface {
// in order to avoid the name from being interpreted as a keyword.
QuoteIdentifier(name string) string
+ // CanUseLastInsertId returns true if this type of database system can report
+ // a single auto-generated int primary key using [sql.Result.LastInsertId].
+ // If true, the RETURNING clause will be omitted for matching INSERT queries.
+ CanUseLastInsertId() bool
+
// UpsertClause generates an "ON CONFLICT" or similar clause
// that can be appended to an INSERT query to make it fall back to
// behave like UPDATE if a record with the same primary key already exists.
@@ -52,6 +63,10 @@ func (mariadbDialect) QuoteIdentifier(name string) string {
return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}
+func (mariadbDialect) CanUseLastInsertId() bool {
+ return true
+}
+
func (d mariadbDialect) UpsertClause(pkColumns, otherColumns []string) string {
clauses := make([]string, max(1, len(otherColumns)))
if len(otherColumns) == 0 {
@@ -81,6 +96,10 @@ func (postgresDialect) QuoteIdentifier(name string) string {
return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}
+func (postgresDialect) CanUseLastInsertId() bool {
+ return false
+}
+
func (d postgresDialect) UpsertClause(pkColumns, otherColumns []string) string {
quotedPkColumns := make([]string, len(pkColumns))
for idx, name := range pkColumns {
@@ -116,6 +135,10 @@ func (sqliteDialect) QuoteIdentifier(name string) string {
return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}
+func (sqliteDialect) CanUseLastInsertId() bool {
+ return true
+}
+
func (sqliteDialect) UpsertClause(pkColumns, otherColumns []string) string {
return postgresDialect{}.UpsertClause(pkColumns, otherColumns)
}
diff --git a/plan.go b/plan.go
index 830899e..dbcc012 100644
--- a/plan.go
+++ b/plan.go
@@ -26,6 +26,16 @@ type plan struct {
// Pointer-typed fields that need to be initialized before scanning into this type.
TransparentPointerStructFields []fieldInfo
+ // Whether the INSERT query uses QueryRow or Exec.
+ // - When no auto-generated values are collected, or when a single value can be collected through LastInsertId(),
+ // this will be false because Exec() is more memory-efficient than QueryRow(); it does not have to allocate an *sql.Rows instance.
+ // - Otherwise, i.e. when auto-generated values are collected with a RETURNING clause,
+ // this will be true because Exec() does not support scanning result values.
+ InsertUsesQueryRow bool
+ // If InsertUsesQueryRow = false and a primary key is collected from LastInsertId(),
+ // this decides whether we write it with reflect.Value.SetInt() or reflect.Value.SetUint().
+ LastInsertIdIsUnsigned bool
+
// Planned queries.
Select plannedQuery // only `SELECT ... FROM ... WHERE `; user supplies the rest during Select{,One}Where()
Insert plannedQuery
@@ -187,6 +197,33 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) {
}
}
+ // pick strategy for INSERT
+ if p.TableName != "" {
+ switch len(p.AutoColumnNames) {
+ case 0:
+ p.InsertUsesQueryRow = false
+ case 1:
+ if dialect.CanUseLastInsertId() {
+ columnName := p.AutoColumnNames[0]
+ field := t.FieldByIndex(p.IndexByColumnName[columnName])
+ switch field.Type.Kind() { //nolint:exhaustive // false positive
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ p.InsertUsesQueryRow = false
+ p.LastInsertIdIsUnsigned = false
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ p.InsertUsesQueryRow = false
+ p.LastInsertIdIsUnsigned = true
+ default:
+ p.InsertUsesQueryRow = true
+ }
+ } else {
+ p.InsertUsesQueryRow = true
+ }
+ default:
+ p.InsertUsesQueryRow = true
+ }
+ }
+
// prepare query strings
p.Select = p.buildSelectQueryIfPossible(dialect)
p.Insert = p.buildInsertQueryIfPossible(dialect, false)
@@ -282,7 +319,7 @@ func (p plan) buildInsertQueryIfPossible(dialect Dialect, isUpsert bool) planned
if isUpsert {
query += dialect.UpsertClause(p.PrimaryKeyColumnNames, p.getNonPrimaryKeyColumnNames())
}
- if len(p.AutoColumnNames) > 0 {
+ if len(p.AutoColumnNames) > 0 && p.InsertUsesQueryRow {
quotedAutoColumns := make([]string, len(p.AutoColumnNames))
for idx, name := range p.AutoColumnNames {
quotedAutoColumns[idx] = dialect.QuoteIdentifier(name)
diff --git a/plan_test.go b/plan_test.go
index f8b4fac..6c42f7a 100644
--- a/plan_test.go
+++ b/plan_test.go
@@ -74,10 +74,12 @@ func TestQueryConstructionBasic(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, false)
+ assert.Equal(t, plan.LastInsertIdIsUnsigned, false)
assert.Equal(t, plan.Select.Query, "SELECT `ID`, `Description`, `CreatedAt` FROM `basic_records` WHERE ")
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
- assert.Equal(t, plan.Insert.Query, "INSERT INTO `basic_records` (`Description`, `CreatedAt`) VALUES (?, ?) RETURNING `ID`")
+ assert.Equal(t, plan.Insert.Query, "INSERT INTO `basic_records` (`Description`, `CreatedAt`) VALUES (?, ?)")
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}})
assert.Equal(t, plan.Upsert.Query, "")
@@ -96,6 +98,7 @@ func TestQueryConstructionBasic(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, true)
assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `)
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
@@ -118,10 +121,12 @@ func TestQueryConstructionBasic(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, false)
+ assert.Equal(t, plan.LastInsertIdIsUnsigned, false)
assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `)
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
- assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?) RETURNING "ID"`)
+ assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?)`)
assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}})
assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}})
assert.Equal(t, plan.Upsert.Query, "")
@@ -151,6 +156,7 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, false)
assert.Equal(t, plan.Select.Query, "SELECT `foo_id`, `bar_id` FROM `foo_bar_relations` WHERE ")
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}})
@@ -173,6 +179,7 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, false)
assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `)
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}})
@@ -195,6 +202,7 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, false)
assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `)
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}})
@@ -227,6 +235,7 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, false)
assert.Equal(t, plan.Select.Query, "SELECT `foo_id`, `bar_id` FROM `foo_bar_relations` WHERE ")
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}})
@@ -249,6 +258,7 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, false)
assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `)
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}})
@@ -271,6 +281,7 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, false)
assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `)
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}})
@@ -342,6 +353,7 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, false)
assert.Equal(t, plan.Select.Query, "SELECT `group_id`, `name`, `created_at` FROM `complex_records` WHERE ")
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
@@ -364,6 +376,7 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, false)
assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `)
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
@@ -386,6 +399,7 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, false)
assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `)
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
@@ -420,6 +434,7 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, true)
assert.Equal(t, plan.Select.Query, "SELECT `id`, `name`, `created_at` FROM `autogenerated_records` WHERE ")
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
@@ -442,6 +457,7 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, true)
assert.Equal(t, plan.Select.Query, `SELECT "id", "name", "created_at" FROM "autogenerated_records" WHERE `)
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
@@ -464,6 +480,7 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) {
if err != nil {
t.Error(err)
}
+ assert.Equal(t, plan.InsertUsesQueryRow, true)
assert.Equal(t, plan.Select.Query, `SELECT "id", "name", "created_at" FROM "autogenerated_records" WHERE `)
assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil)
assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}})
diff --git a/query.go b/query.go
index 79b7abd..853ef37 100644
--- a/query.go
+++ b/query.go
@@ -5,6 +5,7 @@ package oblast
import (
"context"
+ "database/sql"
"fmt"
"reflect"
@@ -72,7 +73,7 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han
if err != nil {
return newIOError(err, "Stmt.Close", stmt.Close())
}
- err = insertRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots)
+ err = insertRecord(ctx, s.plan, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots)
if err != nil {
return newIOError(err, "Stmt.Close", stmt.Close())
}
@@ -81,7 +82,7 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han
return newIOError(nil, "Stmt.Close", stmt.Close())
}
-func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error {
+func insertRecord(ctx context.Context, plan plan, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error {
for idx, index := range argumentIndexes {
argumentSlots[idx] = v.FieldByIndex(index).Interface()
}
@@ -92,16 +93,38 @@ func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt ha
}
scanSlots[idx] = f.Addr().Interface()
}
- var err error
- if len(scanSlots) == 0 {
+
+ var (
+ result sql.Result
+ err error
+ )
+ switch {
+ case len(scanSlots) == 0:
_, err = stmt.Exec(ctx, argumentSlots)
- } else {
- // TODO: using QueryRow for inserting is extremely expensive because database/sql allocates a Rows instance under the hood; other libraries are doing better by limiting themselves to ExecContext() + LastInsertId()
+ case plan.InsertUsesQueryRow:
err = stmt.QueryRow(ctx, argumentSlots, scanSlots)
+ default:
+ result, err = stmt.Exec(ctx, argumentSlots)
}
if err != nil {
return fmt.Errorf("while inserting record with idx = %d: %w", recordIndex, err)
}
+
+ if result != nil {
+ id, err := result.LastInsertId()
+ if err != nil {
+ return fmt.Errorf("while getting LastInsertId for record with idx = %d: %w", recordIndex, err)
+ }
+ if plan.LastInsertIdIsUnsigned {
+ if id < 0 {
+ return fmt.Errorf("LastInsertId() = %d for record with idx = %d cannot be converted to uint", id, recordIndex)
+ }
+ v.FieldByIndex(scanIndexes[0]).SetUint(uint64(id))
+ } else {
+ v.FieldByIndex(scanIndexes[0]).SetInt(id)
+ }
+ }
+
return nil
}
@@ -280,7 +303,7 @@ func (s Store[R]) doUpsert(ctx context.Context, db Handle, insertStmt, updateStm
}
if isInsert {
- err = insertRecord(ctx, v, idx, insertStmt, insertArgumentIndexes, insertArgumentSlots, insertScanIndexes, insertScanSlots)
+ err = insertRecord(ctx, s.plan, v, idx, insertStmt, insertArgumentIndexes, insertArgumentSlots, insertScanIndexes, insertScanSlots)
} else {
var rowsAffected int64
rowsAffected, err = updateRecord(ctx, v, idx, updateStmt, updateArgumentIndexes, updateArgumentSlots)
diff --git a/query_test.go b/query_test.go
index a67dade..6013201 100644
--- a/query_test.go
+++ b/query_test.go
@@ -21,32 +21,93 @@ func TestInsertBasic(t *testing.T) {
db := oblast.NewDB(sql.OpenDB(md))
type basicRecord struct {
- ID int64 `oblast:"id,auto"`
+ ID int64 `db:"id,auto"`
+ Name string `db:"name"`
+ }
+
+ // testing with the SQLite dialect exercises the Exec()-based codepath
+ t.Run("driver=sqlite", func(t *testing.T) {
+ store := oblast.MustNewStore[basicRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} {
+ t.Run("N="+strconv.Itoa(batchSize), func(t *testing.T) {
+ records := make([]*basicRecord, batchSize)
+ for idx := range batchSize {
+ records[idx] = &basicRecord{Name: "new"}
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`).
+ ExpectExecWithArgs("new").
+ AndReturnLastInsertId(int64(42 + idx))
+ }
+ must.Succeed(t, store.Insert(ctx, db, records...))
+ for idx, r := range records {
+ assert.Equal(t, r.ID, int64(42+idx))
+ }
+ })
+ }
+ })
+
+ // testing with the Postgres dialect exercises the QueryRow()-based codepath
+ t.Run("driver=postgres", func(t *testing.T) {
+ store := oblast.MustNewStore[basicRecord](
+ oblast.PostgresDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} {
+ t.Run("N="+strconv.Itoa(batchSize), func(t *testing.T) {
+ records := make([]*basicRecord, batchSize)
+ for idx := range batchSize {
+ records[idx] = &basicRecord{Name: "new"}
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`).
+ ExpectQueryWithArgs("new").
+ AndReturnColumns("id").
+ WithRow(int64(42 + idx))
+ }
+ must.Succeed(t, store.Insert(ctx, db, records...))
+ for idx, r := range records {
+ assert.Equal(t, r.ID, int64(42+idx))
+ }
+ })
+ }
+ })
+}
+
+func TestInsertWithUintPrimaryKey(t *testing.T) {
+ ctx := t.Context()
+ md := mock.NewDriver()
+ db := oblast.NewDB(sql.OpenDB(md))
+
+ type exoticRecord struct {
+ ID uint64 `oblast:"id,auto"`
Name string `oblast:"name"`
}
- store := oblast.MustNewStore[basicRecord](
+ store := oblast.MustNewStore[exoticRecord](
oblast.SqliteDialect(),
oblast.StructTagKeyIs("oblast"), // this test also randomly provides coverage for this option
- oblast.TableNameIs("basic_records"),
+ oblast.TableNameIs("exotic_records"),
oblast.PrimaryKeyIs("id"),
)
- for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} {
- t.Run("N="+strconv.Itoa(batchSize), func(t *testing.T) {
- records := make([]*basicRecord, batchSize)
- for idx := range batchSize {
- records[idx] = &basicRecord{Name: "new"}
- md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`).
- ExpectQueryWithArgs("new").
- AndReturnColumns("id").
- WithRow(int64(42 + idx))
- }
- must.Succeed(t, store.Insert(ctx, db, records...))
- for idx, r := range records {
- assert.Equal(t, r.ID, int64(42+idx))
- }
- })
- }
+ // success case: positive ID fits into uint64
+ md.ForQuery(`INSERT INTO "exotic_records" ("name") VALUES (?)`).
+ ExpectExecWithArgs("new").
+ AndReturnLastInsertId(42)
+ record := exoticRecord{Name: "new"}
+ must.Succeed(t, store.Insert(ctx, db, &record))
+ assert.Equal(t, record.ID, 42)
+
+ // error case: negative ID cannot be converted to uint64
+ md.ForQuery(`INSERT INTO "exotic_records" ("name") VALUES (?)`).
+ ExpectExecWithArgs("another").
+ AndReturnLastInsertId(-42)
+ record = exoticRecord{Name: "another"}
+ err := store.Insert(ctx, db, &record)
+ assert.ErrEqual(t, err, "LastInsertId() = -42 for record with idx = 0 cannot be converted to uint")
}
func TestUpdateBasic(t *testing.T) {
@@ -124,17 +185,15 @@ func TestUpsertBasicWithAutoColumn(t *testing.T) {
oblast.PrimaryKeyIs("id"),
)
- md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`).
- ExpectQueryWithArgs("first needs insert").
- AndReturnColumns("id").
- WithRow(int64(1))
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`).
+ ExpectExecWithArgs("first needs insert").
+ AndReturnLastInsertId(1)
md.ForQuery(`UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`).
ExpectExecWithArgs("second needs update", 2).
AndReturnRowsAffected(1)
- md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`).
- ExpectQueryWithArgs("third needs insert").
- AndReturnColumns("id").
- WithRow(int64(3))
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`).
+ ExpectExecWithArgs("third needs insert").
+ AndReturnLastInsertId(3)
md.ForQuery(`UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`).
ExpectExecWithArgs("fourth needs update", 4).
AndReturnRowsAffected(1)
@@ -208,7 +267,7 @@ func TestWriteQueriesFailDuringPrepare(t *testing.T) {
}
err := store.Insert(ctx, db, recordsForInsert...)
- baseError := `unexpected query: INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`
+ baseError := `unexpected query: INSERT INTO "basic_records" ("name") VALUES (?)`
if batchSize < oblast.PrepareThreshold {
assert.ErrEqual(t, err, "while inserting record with idx = 0: "+baseError)
} else {
@@ -283,10 +342,6 @@ func TestInsertFailsOnFilledAutoField(t *testing.T) {
oblast.PrimaryKeyIs("id"),
)
- md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`).
- ExpectQueryWithArgs("existing").
- AndReturnColumns("id").
- WithRow(42)
err := store.Insert(ctx, db, &basicRecord{ID: 23, Name: "third"})
assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0 that already has non-zero values in its "auto" columns`)
}
@@ -394,6 +449,18 @@ func TestUninitializedTransparentPointerStructs(t *testing.T) {
err = nestedRecordStore.Upsert(ctx, db, &freshBrokenRecord)
assert.ErrEqual(t, err, `refusing to INSERT or UPDATE record with idx = 0: cannot access all mapped fields because field "timestamps" holds a nil pointer`)
+ // check success case on INSERT
+ now := time.Now()
+ freshIntactRecord := nestedRecord{
+ Name: "foo",
+ timestamps: &timestamps{CreatedAt: now, DeletedAt: nil},
+ }
+ md.ForQuery(`INSERT INTO "nested_records" ("name", "created_at", "deleted_at") VALUES (?, ?, ?)`).
+ ExpectExecWithArgs("foo", now, (*time.Time)(nil)).
+ AndReturnLastInsertId(1)
+ must.Succeed(t, nestedRecordStore.Insert(ctx, db, &freshIntactRecord))
+ assert.Equal(t, freshIntactRecord.ID, 1)
+
// check detection on UPDATE
existingBrokenRecord := nestedRecord{
ID: 42,
@@ -405,6 +472,18 @@ func TestUninitializedTransparentPointerStructs(t *testing.T) {
err = nestedRecordStore.Upsert(ctx, db, &freshBrokenRecord)
assert.ErrEqual(t, err, `refusing to INSERT or UPDATE record with idx = 0: cannot access all mapped fields because field "timestamps" holds a nil pointer`)
+ // check success case on UPDATE
+ now = time.Now()
+ existingIntactRecord := nestedRecord{
+ ID: 42,
+ Name: "bar",
+ timestamps: &timestamps{CreatedAt: now, DeletedAt: nil},
+ }
+ md.ForQuery(`UPDATE "nested_records" SET "name" = ?, "created_at" = ?, "deleted_at" = ? WHERE "id" = ?`).
+ ExpectExecWithArgs("bar", now, (*time.Time)(nil), 42).
+ AndReturnRowsAffected(1)
+ must.Succeed(t, nestedRecordStore.Update(ctx, db, existingIntactRecord))
+
// check that detection on DELETE does not care about transparent pointer structs as long as they do not contain PK fields
md.ForQuery(`DELETE FROM "nested_records" WHERE "id" = ?`).
ExpectExecWithArgs(42).