From 0843d2cda294bb67d6e65d585c6fd63807d70619 Mon Sep 17 00:00:00 2001 From: Stefan Majewsky Date: Sat, 18 Apr 2026 16:00:52 +0200 Subject: fix Store.Insert() failing on tables without auto columns --- query.go | 38 +++++++++++++++++++++++--------------- query_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 15 deletions(-) diff --git a/query.go b/query.go index 04ba647..1392dfb 100644 --- a/query.go +++ b/query.go @@ -81,7 +81,7 @@ func (s preparedStatement) QueryRow(args ...any) *sql.Row { func (s Store[R]) Insert(db Handle, records ...R) ([]R, error) { // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. // Any expression that does not depend on type R should be factored out into a reusable function. - if s.dialect.UsesLastInsertID() { + if s.dialect.UsesLastInsertID() || len(s.plan.Insert.ScanIndexes) == 0 { return s.insertUsingLastInsertID(db, records) } else { return s.insertUsingReturningClause(db, records) @@ -103,8 +103,11 @@ func (s Store[R]) insertUsingLastInsertID(db Handle, records []R) (returnedRecor var ( argumentIndexes = s.plan.Insert.ArgumentIndexes argumentSlots = make([]any, len(argumentIndexes)) - scanIndex = s.plan.Insert.ScanIndexes[0] + scanIndex []int ) + if len(s.plan.Insert.ScanIndexes) > 0 { + scanIndex = s.plan.Insert.ScanIndexes[0] + } for idx := range records { v := reflect.ValueOf(&records[idx]).Elem() err := insertRecordUsingLastInsertID(v, idx, stmt, argumentIndexes, argumentSlots, scanIndex, s.plan) @@ -120,26 +123,31 @@ func insertRecordUsingLastInsertID(v reflect.Value, recordIndex int, stmt prepar for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } - scanField := v.FieldByIndex(scanIndex) - if !scanField.IsZero() { - return fmt.Errorf(`refusing to INSERT record with idx = %d that already has non-zero values in its "auto" columns`, recordIndex) + var scanField reflect.Value + if scanIndex != nil { + scanField = v.FieldByIndex(scanIndex) + if !scanField.IsZero() { + return fmt.Errorf(`refusing to INSERT record with idx = %d that already has non-zero values in its "auto" columns`, recordIndex) + } } result, err := stmt.Exec(argumentSlots...) if err != nil { return fmt.Errorf("during Exec() for record with idx = %d: %w", recordIndex, err) } - id, err := result.LastInsertId() - if err != nil { - return fmt.Errorf("during LastInsertId() for record with idx = %d: %w", recordIndex, err) - } - if plan.FillIDWithSetInt { - scanField.SetInt(id) - } else if plan.FillIDWithSetUint { - if id < 0 { - return fmt.Errorf("LastInsertId() = %d for record with idx = %d cannot be converted to uint", id, recordIndex) + if scanIndex != nil { + id, err := result.LastInsertId() + if err != nil { + return fmt.Errorf("during LastInsertId() for record with idx = %d: %w", recordIndex, err) + } + if plan.FillIDWithSetInt { + scanField.SetInt(id) + } else if plan.FillIDWithSetUint { + if id < 0 { + return fmt.Errorf("LastInsertId() = %d for record with idx = %d cannot be converted to uint", id, recordIndex) + } + scanField.SetUint(uint64(id)) } - scanField.SetUint(uint64(id)) } return nil } diff --git a/query_test.go b/query_test.go index 000c385..7dde757 100644 --- a/query_test.go +++ b/query_test.go @@ -328,3 +328,48 @@ func TestInsertWithUnsignedIdField(t *testing.T) { assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0 that already has non-zero values in its "auto" columns`) }) } + +func TestInsertWithoutAutoColumns(t *testing.T) { + md := mock.NewDriver() + db := sql.OpenDB(md) + + type relation struct { + FooID int64 `db:"foo_id"` + BarID int64 `db:"bar_id"` + } + + // Even in dialects using RETURNING clause, this uses Exec() because there is nothing to return. + // Therefore, the test behavior with both dialects is identical except for the different placeholder syntax in the query. + runTest := func(store oblast.Store[relation], query string) { + md.ForQuery(query). + ExpectExecWithArgs(1, 2). + AndReturnRowsAffected(1) + md.ForQuery(query). + ExpectExecWithArgs(1, 3). + AndReturnRowsAffected(1) + relations := []relation{ + {FooID: 1, BarID: 2}, + {FooID: 1, BarID: 3}, + } + insertedRelations := must.Return(store.Insert(db, relations...))(t) + assert.SliceEqual(t, insertedRelations, relations...) + } + + t.Run("in dialect using LastInsertID", func(t *testing.T) { + store := oblast.MustNewStore[relation]( + oblast.SqliteDialect(), + oblast.TableNameIs("foo_bar_relations"), + oblast.PrimaryKeyIs("foo_id", "bar_id"), + ) + runTest(store, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`) + }) + + t.Run("in dialect using RETURNING clause", func(t *testing.T) { + store := oblast.MustNewStore[relation]( + oblast.PostgresDialect(), + oblast.TableNameIs("foo_bar_relations"), + oblast.PrimaryKeyIs("foo_id", "bar_id"), + ) + runTest(store, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2)`) + }) +} -- cgit v1.2.3