diff options
| author | Stefan Majewsky <majewsky@gmx.net> | 2026-05-22 14:01:24 +0200 |
|---|---|---|
| committer | Stefan Majewsky <majewsky@gmx.net> | 2026-05-22 14:01:24 +0200 |
| commit | 764eaf643e323b92a616fc8e6a193855bb43d905 (patch) | |
| tree | 935827e791480719a1cf63f806c7e21006a0fb19 | |
| parent | 091f9b68a70d617a38ddf7a662aaf351724be746 (diff) | |
| download | go-oblast-764eaf643e323b92a616fc8e6a193855bb43d905.tar.gz | |
bring back support for LastInsertId-based INSERT
As the remaining TODO noted, this really is much more memory-efficient
than QueryRow when we can use it, since it does not allocate an
*sql.Rows instance inside the *sql.Row instance where we call Scan().
| -rw-r--r-- | CHANGELOG.md | 8 | ||||
| -rw-r--r-- | dialect.go | 23 | ||||
| -rw-r--r-- | plan.go | 39 | ||||
| -rw-r--r-- | plan_test.go | 21 | ||||
| -rw-r--r-- | query.go | 37 | ||||
| -rw-r--r-- | query_test.go | 143 |
6 files changed, 229 insertions, 42 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index a311b01..41dcdfd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,14 @@ SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net> SPDX-License-Identifier: Apache-2.0 --> +# v0.10.0 (TBD) + +Changes: + +- Dialects that support it (i.e. MariaDB and SQLite) will once again prefer collecting autogenerated IDs through `LastInsertId()`. + RETURNING clauses will only be used when multiple fields have the `db:",auto"` tag. + This improves memory consumption for INSERT and UPSERT queries on those dialects. + # v0.9.0 (2026-05-18) API changes: @@ -4,11 +4,17 @@ package oblast import ( + "database/sql" "fmt" "strconv" "strings" ) +var ( + // force imports to make docstring links work + _ = sql.Result(nil) +) + // Dialect accounts for differences between different SQL dialects // that are relevant to query generation within Oblast. // @@ -27,6 +33,11 @@ type Dialect interface { // in order to avoid the name from being interpreted as a keyword. QuoteIdentifier(name string) string + // CanUseLastInsertId returns true if this type of database system can report + // a single auto-generated int primary key using [sql.Result.LastInsertId]. + // If true, the RETURNING clause will be omitted for matching INSERT queries. + CanUseLastInsertId() bool + // UpsertClause generates an "ON CONFLICT" or similar clause // that can be appended to an INSERT query to make it fall back to // behave like UPDATE if a record with the same primary key already exists. @@ -52,6 +63,10 @@ func (mariadbDialect) QuoteIdentifier(name string) string { return "`" + strings.ReplaceAll(name, "`", "``") + "`" } +func (mariadbDialect) CanUseLastInsertId() bool { + return true +} + func (d mariadbDialect) UpsertClause(pkColumns, otherColumns []string) string { clauses := make([]string, max(1, len(otherColumns))) if len(otherColumns) == 0 { @@ -81,6 +96,10 @@ func (postgresDialect) QuoteIdentifier(name string) string { return `"` + strings.ReplaceAll(name, `"`, `""`) + `"` } +func (postgresDialect) CanUseLastInsertId() bool { + return false +} + func (d postgresDialect) UpsertClause(pkColumns, otherColumns []string) string { quotedPkColumns := make([]string, len(pkColumns)) for idx, name := range pkColumns { @@ -116,6 +135,10 @@ func (sqliteDialect) QuoteIdentifier(name string) string { return `"` + strings.ReplaceAll(name, `"`, `""`) + `"` } +func (sqliteDialect) CanUseLastInsertId() bool { + return true +} + func (sqliteDialect) UpsertClause(pkColumns, otherColumns []string) string { return postgresDialect{}.UpsertClause(pkColumns, otherColumns) } @@ -26,6 +26,16 @@ type plan struct { // Pointer-typed fields that need to be initialized before scanning into this type. TransparentPointerStructFields []fieldInfo + // Whether the INSERT query uses QueryRow or Exec. + // - When no auto-generated values are collected, or when a single value can be collected through LastInsertId(), + // this will be false because Exec() is more memory-efficient than QueryRow(); it does not have to allocate an *sql.Rows instance. + // - Otherwise, i.e. when auto-generated values are collected with a RETURNING clause, + // this will be true because Exec() does not support scanning result values. + InsertUsesQueryRow bool + // If InsertUsesQueryRow = false and a primary key is collected from LastInsertId(), + // this decides whether we write it with reflect.Value.SetInt() or reflect.Value.SetUint(). + LastInsertIdIsUnsigned bool + // Planned queries. Select plannedQuery // only `SELECT ... FROM ... WHERE `; user supplies the rest during Select{,One}Where() Insert plannedQuery @@ -187,6 +197,33 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) { } } + // pick strategy for INSERT + if p.TableName != "" { + switch len(p.AutoColumnNames) { + case 0: + p.InsertUsesQueryRow = false + case 1: + if dialect.CanUseLastInsertId() { + columnName := p.AutoColumnNames[0] + field := t.FieldByIndex(p.IndexByColumnName[columnName]) + switch field.Type.Kind() { //nolint:exhaustive // false positive + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + p.InsertUsesQueryRow = false + p.LastInsertIdIsUnsigned = false + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + p.InsertUsesQueryRow = false + p.LastInsertIdIsUnsigned = true + default: + p.InsertUsesQueryRow = true + } + } else { + p.InsertUsesQueryRow = true + } + default: + p.InsertUsesQueryRow = true + } + } + // prepare query strings p.Select = p.buildSelectQueryIfPossible(dialect) p.Insert = p.buildInsertQueryIfPossible(dialect, false) @@ -282,7 +319,7 @@ func (p plan) buildInsertQueryIfPossible(dialect Dialect, isUpsert bool) planned if isUpsert { query += dialect.UpsertClause(p.PrimaryKeyColumnNames, p.getNonPrimaryKeyColumnNames()) } - if len(p.AutoColumnNames) > 0 { + if len(p.AutoColumnNames) > 0 && p.InsertUsesQueryRow { quotedAutoColumns := make([]string, len(p.AutoColumnNames)) for idx, name := range p.AutoColumnNames { quotedAutoColumns[idx] = dialect.QuoteIdentifier(name) diff --git a/plan_test.go b/plan_test.go index f8b4fac..6c42f7a 100644 --- a/plan_test.go +++ b/plan_test.go @@ -74,10 +74,12 @@ func TestQueryConstructionBasic(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, false) + assert.Equal(t, plan.LastInsertIdIsUnsigned, false) assert.Equal(t, plan.Select.Query, "SELECT `ID`, `Description`, `CreatedAt` FROM `basic_records` WHERE ") assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, "INSERT INTO `basic_records` (`Description`, `CreatedAt`) VALUES (?, ?) RETURNING `ID`") + assert.Equal(t, plan.Insert.Query, "INSERT INTO `basic_records` (`Description`, `CreatedAt`) VALUES (?, ?)") assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) assert.Equal(t, plan.Upsert.Query, "") @@ -96,6 +98,7 @@ func TestQueryConstructionBasic(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, true) assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) @@ -118,10 +121,12 @@ func TestQueryConstructionBasic(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, false) + assert.Equal(t, plan.LastInsertIdIsUnsigned, false) assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?) RETURNING "ID"`) + assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) assert.Equal(t, plan.Upsert.Query, "") @@ -151,6 +156,7 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, false) assert.Equal(t, plan.Select.Query, "SELECT `foo_id`, `bar_id` FROM `foo_bar_relations` WHERE ") assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) @@ -173,6 +179,7 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, false) assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) @@ -195,6 +202,7 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, false) assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) @@ -227,6 +235,7 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, false) assert.Equal(t, plan.Select.Query, "SELECT `foo_id`, `bar_id` FROM `foo_bar_relations` WHERE ") assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) @@ -249,6 +258,7 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, false) assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) @@ -271,6 +281,7 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, false) assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}}) @@ -342,6 +353,7 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, false) assert.Equal(t, plan.Select.Query, "SELECT `group_id`, `name`, `created_at` FROM `complex_records` WHERE ") assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) @@ -364,6 +376,7 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, false) assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `) assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) @@ -386,6 +399,7 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, false) assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `) assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) @@ -420,6 +434,7 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, true) assert.Equal(t, plan.Select.Query, "SELECT `id`, `name`, `created_at` FROM `autogenerated_records` WHERE ") assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) @@ -442,6 +457,7 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, true) assert.Equal(t, plan.Select.Query, `SELECT "id", "name", "created_at" FROM "autogenerated_records" WHERE `) assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) @@ -464,6 +480,7 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.InsertUsesQueryRow, true) assert.Equal(t, plan.Select.Query, `SELECT "id", "name", "created_at" FROM "autogenerated_records" WHERE `) assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) @@ -5,6 +5,7 @@ package oblast import ( "context" + "database/sql" "fmt" "reflect" @@ -72,7 +73,7 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han if err != nil { return newIOError(err, "Stmt.Close", stmt.Close()) } - err = insertRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots) + err = insertRecord(ctx, s.plan, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots) if err != nil { return newIOError(err, "Stmt.Close", stmt.Close()) } @@ -81,7 +82,7 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han return newIOError(nil, "Stmt.Close", stmt.Close()) } -func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error { +func insertRecord(ctx context.Context, plan plan, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error { for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } @@ -92,16 +93,38 @@ func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt ha } scanSlots[idx] = f.Addr().Interface() } - var err error - if len(scanSlots) == 0 { + + var ( + result sql.Result + err error + ) + switch { + case len(scanSlots) == 0: _, err = stmt.Exec(ctx, argumentSlots) - } else { - // TODO: using QueryRow for inserting is extremely expensive because database/sql allocates a Rows instance under the hood; other libraries are doing better by limiting themselves to ExecContext() + LastInsertId() + case plan.InsertUsesQueryRow: err = stmt.QueryRow(ctx, argumentSlots, scanSlots) + default: + result, err = stmt.Exec(ctx, argumentSlots) } if err != nil { return fmt.Errorf("while inserting record with idx = %d: %w", recordIndex, err) } + + if result != nil { + id, err := result.LastInsertId() + if err != nil { + return fmt.Errorf("while getting LastInsertId for record with idx = %d: %w", recordIndex, err) + } + if plan.LastInsertIdIsUnsigned { + if id < 0 { + return fmt.Errorf("LastInsertId() = %d for record with idx = %d cannot be converted to uint", id, recordIndex) + } + v.FieldByIndex(scanIndexes[0]).SetUint(uint64(id)) + } else { + v.FieldByIndex(scanIndexes[0]).SetInt(id) + } + } + return nil } @@ -280,7 +303,7 @@ func (s Store[R]) doUpsert(ctx context.Context, db Handle, insertStmt, updateStm } if isInsert { - err = insertRecord(ctx, v, idx, insertStmt, insertArgumentIndexes, insertArgumentSlots, insertScanIndexes, insertScanSlots) + err = insertRecord(ctx, s.plan, v, idx, insertStmt, insertArgumentIndexes, insertArgumentSlots, insertScanIndexes, insertScanSlots) } else { var rowsAffected int64 rowsAffected, err = updateRecord(ctx, v, idx, updateStmt, updateArgumentIndexes, updateArgumentSlots) diff --git a/query_test.go b/query_test.go index a67dade..6013201 100644 --- a/query_test.go +++ b/query_test.go @@ -21,32 +21,93 @@ func TestInsertBasic(t *testing.T) { db := oblast.NewDB(sql.OpenDB(md)) type basicRecord struct { - ID int64 `oblast:"id,auto"` + ID int64 `db:"id,auto"` + Name string `db:"name"` + } + + // testing with the SQLite dialect exercises the Exec()-based codepath + t.Run("driver=sqlite", func(t *testing.T) { + store := oblast.MustNewStore[basicRecord]( + oblast.SqliteDialect(), + oblast.TableNameIs("basic_records"), + oblast.PrimaryKeyIs("id"), + ) + + for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} { + t.Run("N="+strconv.Itoa(batchSize), func(t *testing.T) { + records := make([]*basicRecord, batchSize) + for idx := range batchSize { + records[idx] = &basicRecord{Name: "new"} + md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`). + ExpectExecWithArgs("new"). + AndReturnLastInsertId(int64(42 + idx)) + } + must.Succeed(t, store.Insert(ctx, db, records...)) + for idx, r := range records { + assert.Equal(t, r.ID, int64(42+idx)) + } + }) + } + }) + + // testing with the Postgres dialect exercises the QueryRow()-based codepath + t.Run("driver=postgres", func(t *testing.T) { + store := oblast.MustNewStore[basicRecord]( + oblast.PostgresDialect(), + oblast.TableNameIs("basic_records"), + oblast.PrimaryKeyIs("id"), + ) + + for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} { + t.Run("N="+strconv.Itoa(batchSize), func(t *testing.T) { + records := make([]*basicRecord, batchSize) + for idx := range batchSize { + records[idx] = &basicRecord{Name: "new"} + md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`). + ExpectQueryWithArgs("new"). + AndReturnColumns("id"). + WithRow(int64(42 + idx)) + } + must.Succeed(t, store.Insert(ctx, db, records...)) + for idx, r := range records { + assert.Equal(t, r.ID, int64(42+idx)) + } + }) + } + }) +} + +func TestInsertWithUintPrimaryKey(t *testing.T) { + ctx := t.Context() + md := mock.NewDriver() + db := oblast.NewDB(sql.OpenDB(md)) + + type exoticRecord struct { + ID uint64 `oblast:"id,auto"` Name string `oblast:"name"` } - store := oblast.MustNewStore[basicRecord]( + store := oblast.MustNewStore[exoticRecord]( oblast.SqliteDialect(), oblast.StructTagKeyIs("oblast"), // this test also randomly provides coverage for this option - oblast.TableNameIs("basic_records"), + oblast.TableNameIs("exotic_records"), oblast.PrimaryKeyIs("id"), ) - for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} { - t.Run("N="+strconv.Itoa(batchSize), func(t *testing.T) { - records := make([]*basicRecord, batchSize) - for idx := range batchSize { - records[idx] = &basicRecord{Name: "new"} - md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`). - ExpectQueryWithArgs("new"). - AndReturnColumns("id"). - WithRow(int64(42 + idx)) - } - must.Succeed(t, store.Insert(ctx, db, records...)) - for idx, r := range records { - assert.Equal(t, r.ID, int64(42+idx)) - } - }) - } + // success case: positive ID fits into uint64 + md.ForQuery(`INSERT INTO "exotic_records" ("name") VALUES (?)`). + ExpectExecWithArgs("new"). + AndReturnLastInsertId(42) + record := exoticRecord{Name: "new"} + must.Succeed(t, store.Insert(ctx, db, &record)) + assert.Equal(t, record.ID, 42) + + // error case: negative ID cannot be converted to uint64 + md.ForQuery(`INSERT INTO "exotic_records" ("name") VALUES (?)`). + ExpectExecWithArgs("another"). + AndReturnLastInsertId(-42) + record = exoticRecord{Name: "another"} + err := store.Insert(ctx, db, &record) + assert.ErrEqual(t, err, "LastInsertId() = -42 for record with idx = 0 cannot be converted to uint") } func TestUpdateBasic(t *testing.T) { @@ -124,17 +185,15 @@ func TestUpsertBasicWithAutoColumn(t *testing.T) { oblast.PrimaryKeyIs("id"), ) - md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`). - ExpectQueryWithArgs("first needs insert"). - AndReturnColumns("id"). - WithRow(int64(1)) + md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`). + ExpectExecWithArgs("first needs insert"). + AndReturnLastInsertId(1) md.ForQuery(`UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`). ExpectExecWithArgs("second needs update", 2). AndReturnRowsAffected(1) - md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`). - ExpectQueryWithArgs("third needs insert"). - AndReturnColumns("id"). - WithRow(int64(3)) + md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`). + ExpectExecWithArgs("third needs insert"). + AndReturnLastInsertId(3) md.ForQuery(`UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`). ExpectExecWithArgs("fourth needs update", 4). AndReturnRowsAffected(1) @@ -208,7 +267,7 @@ func TestWriteQueriesFailDuringPrepare(t *testing.T) { } err := store.Insert(ctx, db, recordsForInsert...) - baseError := `unexpected query: INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"` + baseError := `unexpected query: INSERT INTO "basic_records" ("name") VALUES (?)` if batchSize < oblast.PrepareThreshold { assert.ErrEqual(t, err, "while inserting record with idx = 0: "+baseError) } else { @@ -283,10 +342,6 @@ func TestInsertFailsOnFilledAutoField(t *testing.T) { oblast.PrimaryKeyIs("id"), ) - md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`). - ExpectQueryWithArgs("existing"). - AndReturnColumns("id"). - WithRow(42) err := store.Insert(ctx, db, &basicRecord{ID: 23, Name: "third"}) assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0 that already has non-zero values in its "auto" columns`) } @@ -394,6 +449,18 @@ func TestUninitializedTransparentPointerStructs(t *testing.T) { err = nestedRecordStore.Upsert(ctx, db, &freshBrokenRecord) assert.ErrEqual(t, err, `refusing to INSERT or UPDATE record with idx = 0: cannot access all mapped fields because field "timestamps" holds a nil pointer`) + // check success case on INSERT + now := time.Now() + freshIntactRecord := nestedRecord{ + Name: "foo", + timestamps: ×tamps{CreatedAt: now, DeletedAt: nil}, + } + md.ForQuery(`INSERT INTO "nested_records" ("name", "created_at", "deleted_at") VALUES (?, ?, ?)`). + ExpectExecWithArgs("foo", now, (*time.Time)(nil)). + AndReturnLastInsertId(1) + must.Succeed(t, nestedRecordStore.Insert(ctx, db, &freshIntactRecord)) + assert.Equal(t, freshIntactRecord.ID, 1) + // check detection on UPDATE existingBrokenRecord := nestedRecord{ ID: 42, @@ -405,6 +472,18 @@ func TestUninitializedTransparentPointerStructs(t *testing.T) { err = nestedRecordStore.Upsert(ctx, db, &freshBrokenRecord) assert.ErrEqual(t, err, `refusing to INSERT or UPDATE record with idx = 0: cannot access all mapped fields because field "timestamps" holds a nil pointer`) + // check success case on UPDATE + now = time.Now() + existingIntactRecord := nestedRecord{ + ID: 42, + Name: "bar", + timestamps: ×tamps{CreatedAt: now, DeletedAt: nil}, + } + md.ForQuery(`UPDATE "nested_records" SET "name" = ?, "created_at" = ?, "deleted_at" = ? WHERE "id" = ?`). + ExpectExecWithArgs("bar", now, (*time.Time)(nil), 42). + AndReturnRowsAffected(1) + must.Succeed(t, nestedRecordStore.Update(ctx, db, existingIntactRecord)) + // check that detection on DELETE does not care about transparent pointer structs as long as they do not contain PK fields md.ForQuery(`DELETE FROM "nested_records" WHERE "id" = ?`). ExpectExecWithArgs(42). |
