diff options
| -rw-r--r-- | CHANGELOG.md | 11 | ||||
| -rw-r--r-- | dialect.go | 51 | ||||
| -rw-r--r-- | plan.go | 38 | ||||
| -rw-r--r-- | plan_test.go | 76 | ||||
| -rw-r--r-- | query.go | 91 | ||||
| -rw-r--r-- | query_test.go | 193 |
6 files changed, 122 insertions, 338 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md index e5e73eb..8d7c623 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,17 @@ SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net> SPDX-License-Identifier: Apache-2.0 --> +# v0.3.0 (TBD) + +Changes: + +- `Store.Insert()` now takes its arguments by-pointer. This is probably slightly less efficient, + but significantly safer because autogenerated field values cannot be disregarded by accident. +- Removed support for SQL dialects that rely on LastInsertId() for ID columns. + Using a RETURNING clause to collect autogenerated field values is objectively better in every way, + and has been supported by both MariaDB and SQLite for at least six years. + In practice, this only drops support specifically for Oracle MySQL. + # v0.2.0 (2026-04-18) Changes: @@ -27,18 +27,6 @@ type Dialect interface { // in order to avoid the name from being interpreted as a keyword. QuoteIdentifier(name string) string - // UsesLastInsertID returns whether values for auto-generated columns are - // collected from LastInsertID(). If false, the INSERT query must instead - // yield a result row containing the values. - UsesLastInsertID() bool - - // InsertSuffixForAutoColumns is appended to `INSERT (...) VALUES (...)` - // statements to collect values for auto-filled columns. - // - // If UsesLastInsertID is true, this is usually not needed and the empty - // string can be returned. - InsertSuffixForAutoColumns(columns []string) string - // UpsertClause generates an "ON CONFLICT" or similar clause // that can be appended to an INSERT query to make it fall back to // behave like UPDATE if a record with the same primary key already exists. @@ -46,19 +34,20 @@ type Dialect interface { UpsertClause(pkColumns, otherColumns []string) string } -// MysqlDialect is the dialect of MySQL and MariaDB databases. -func MysqlDialect() Dialect { - return mysqlDialect{} +// MariaDBDialect is the dialect of MariaDB 10.5+ databases. +// +// This dialect does NOT support MySQL, as well as ancient MariaDB versions (10.5 was released 2020-06-24), +// because those do not understand the "INSERT ... RETURNING" syntax. +func MariaDBDialect() Dialect { + return mariadbDialect{} } -type mysqlDialect struct{} +type mariadbDialect struct{} -func (mysqlDialect) Placeholder(_ int) string { return "?" } -func (mysqlDialect) QuoteIdentifier(name string) string { return "`" + name + "`" } -func (mysqlDialect) UsesLastInsertID() bool { return true } -func (mysqlDialect) InsertSuffixForAutoColumns(columns []string) string { return "" } +func (mariadbDialect) Placeholder(_ int) string { return "?" } +func (mariadbDialect) QuoteIdentifier(name string) string { return "`" + name + "`" } -func (d mysqlDialect) UpsertClause(pkColumns, otherColumns []string) string { +func (d mariadbDialect) UpsertClause(pkColumns, otherColumns []string) string { clauses := make([]string, max(1, len(otherColumns))) if len(otherColumns) == 0 { // we need at least one UPDATE clause; if there are no non-PK columns, @@ -81,15 +70,6 @@ type postgresDialect struct{} func (postgresDialect) Placeholder(i int) string { return "$" + strconv.Itoa(i+1) } func (postgresDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } -func (postgresDialect) UsesLastInsertID() bool { return false } - -func (d postgresDialect) InsertSuffixForAutoColumns(columns []string) string { - quotedColumns := make([]string, len(columns)) - for idx, name := range columns { - quotedColumns[idx] = d.QuoteIdentifier(name) - } - return ` RETURNING ` + strings.Join(quotedColumns, ", ") -} func (d postgresDialect) UpsertClause(pkColumns, otherColumns []string) string { quotedPkColumns := make([]string, len(pkColumns)) @@ -108,17 +88,18 @@ func (d postgresDialect) UpsertClause(pkColumns, otherColumns []string) string { } } -// SqliteDialect is the dialect of SQLite databases. +// SqliteDialect is the dialect of SQLite 3.24.0+ databases. +// +// This dialect does NOT support ancient SQLite versions (3.24.0 was released 2018-06-04) +// that do not understand the "INSERT ... RETURNING" syntax. func SqliteDialect() Dialect { return sqliteDialect{} } type sqliteDialect struct{} -func (sqliteDialect) Placeholder(_ int) string { return "?" } -func (sqliteDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } -func (sqliteDialect) UsesLastInsertID() bool { return true } -func (sqliteDialect) InsertSuffixForAutoColumns(columns []string) string { return "" } +func (sqliteDialect) Placeholder(_ int) string { return "?" } +func (sqliteDialect) QuoteIdentifier(name string) string { return `"` + name + `"` } func (sqliteDialect) UpsertClause(pkColumns, otherColumns []string) string { return postgresDialect{}.UpsertClause(pkColumns, otherColumns) } @@ -26,10 +26,6 @@ type plan struct { // Indexes of pointer-typed fields that need to be initialized before scanning into this type. IndexesOfTransparentPointerStructs [][]int - // In dialects with UsesLastInsertID() == true, whether the ID column must be written with reflect.Value.SetInt() or reflect.Value.SetUint(). - FillIDWithSetUint bool - FillIDWithSetInt bool - // Planned queries. Select plannedQuery // only `SELECT ... FROM ... WHERE `; user supplies the rest during Select{,One}Where() Insert plannedQuery @@ -171,32 +167,6 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) { } } - // validation: LastInsertID() only works if at most one column is auto-filled, and if that column holds an integer type - if dialect.UsesLastInsertID() { - switch len(p.AutoColumnNames) { - case 0: - // nothing to check - case 1: - columnName := p.AutoColumnNames[0] - field := t.FieldByIndex(p.IndexByColumnName[columnName]) - switch field.Type.Kind() { //nolint:exhaustive // false positive - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - p.FillIDWithSetInt = true - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - p.FillIDWithSetUint = true - default: - return plan{}, fmt.Errorf( - "column %q is marked as auto-filled, but this SQL dialect only supports auto-filling struct fields with integer types", - columnName) - } - default: - return plan{}, fmt.Errorf( - "multiple columns are marked as auto-filled (%s), but this SQL dialect only supports at most one per table", - strings.Join(p.AutoColumnNames, ", "), - ) - } - } - // prepare query strings p.Select = p.buildSelectQueryIfPossible(dialect) p.Insert = p.buildInsertQueryIfPossible(dialect, false) @@ -277,8 +247,6 @@ func (p plan) buildInsertQueryIfPossible(dialect Dialect, isUpsert bool) planned quotedPlaceholders[idx] = dialect.Placeholder(idx) } if len(p.AutoColumnNames) > 0 { - // NOTE: This is filled even if dialect.UsesLastInsertID() is false. - // We need this index to find the right value on which to run SetInt() or SetUint(). scanIndexes = make([][]int, len(p.AutoColumnNames)) for idx, columnName := range p.AutoColumnNames { scanIndexes[idx] = p.IndexByColumnName[columnName] @@ -295,7 +263,11 @@ func (p plan) buildInsertQueryIfPossible(dialect Dialect, isUpsert bool) planned query += dialect.UpsertClause(p.PrimaryKeyColumnNames, p.getNonPrimaryKeyColumnNames()) } if len(p.AutoColumnNames) > 0 { - query += dialect.InsertSuffixForAutoColumns(p.AutoColumnNames) + quotedAutoColumns := make([]string, len(p.AutoColumnNames)) + for idx, name := range p.AutoColumnNames { + quotedAutoColumns[idx] = dialect.QuoteIdentifier(name) + } + query += ` RETURNING ` + strings.Join(quotedAutoColumns, ", ") } return plannedQuery{query, argumentIndexes, scanIndexes} } diff --git a/plan_test.go b/plan_test.go index 08c7252..b3eeac5 100644 --- a/plan_test.go +++ b/plan_test.go @@ -69,15 +69,15 @@ func TestQueryConstructionBasic(t *testing.T) { PrimaryKeyColumnNames: []string{"ID"}, } - t.Run("MysqlDialect", func(t *testing.T) { - plan, err := buildPlan(reflect.TypeFor[record](), MysqlDialect(), opts) + t.Run("MariaDBDialect", func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[record](), MariaDBDialect(), opts) if err != nil { t.Error(err) } assert.Equal(t, plan.Select.Query, "SELECT `ID`, `Description`, `CreatedAt` FROM `basic_records` WHERE ") assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, "INSERT INTO `basic_records` (`Description`, `CreatedAt`) VALUES (?, ?)") + assert.Equal(t, plan.Insert.Query, "INSERT INTO `basic_records` (`Description`, `CreatedAt`) VALUES (?, ?) RETURNING `ID`") assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) assert.Equal(t, plan.Upsert.Query, "") @@ -121,7 +121,7 @@ func TestQueryConstructionBasic(t *testing.T) { assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) - assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?)`) + assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?) RETURNING "ID"`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}}) assert.Equal(t, plan.Upsert.Query, "") @@ -146,8 +146,8 @@ func TestQueryConstructionWithOnlyPrimaryKey(t *testing.T) { PrimaryKeyColumnNames: []string{"foo_id", "bar_id"}, } - t.Run("MysqlDialect", func(t *testing.T) { - plan, err := buildPlan(reflect.TypeFor[relation](), MysqlDialect(), opts) + t.Run("MariaDBDialect", func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[relation](), MariaDBDialect(), opts) if err != nil { t.Error(err) } @@ -222,8 +222,8 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { TableName: "foo_bar_relations", } - t.Run("MysqlDialect", func(t *testing.T) { - plan, err := buildPlan(reflect.TypeFor[relation](), MysqlDialect(), opts) + t.Run("MariaDBDialect", func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[relation](), MariaDBDialect(), opts) if err != nil { t.Error(err) } @@ -321,7 +321,7 @@ func TestQueryConstructionImpossble(t *testing.T) { } } - t.Run("MysqlDialect", testWith(MysqlDialect())) + t.Run("MariaDBDialect", testWith(MariaDBDialect())) t.Run("PostgresDialect", testWith(PostgresDialect())) t.Run("SqliteDialect", testWith(SqliteDialect())) } @@ -337,8 +337,8 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { PrimaryKeyColumnNames: []string{"group_id", "name"}, } - t.Run("MysqlDialect", func(t *testing.T) { - plan, err := buildPlan(reflect.TypeFor[record](), MysqlDialect(), opts) + t.Run("MariaDBDialect", func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[record](), MariaDBDialect(), opts) if err != nil { t.Error(err) } @@ -415,9 +415,26 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) { PrimaryKeyColumnNames: []string{"id"}, } - t.Run("MysqlDialect", func(t *testing.T) { - _, err := NewStore[record](MysqlDialect()) - assert.Equal(t, err.Error(), `cannot use type oblast.record for queries: multiple columns are marked as auto-filled (id, created_at), but this SQL dialect only supports at most one per table`) + t.Run("MariaDBDialect", func(t *testing.T) { + plan, err := buildPlan(reflect.TypeFor[record](), MariaDBDialect(), opts) + if err != nil { + t.Error(err) + } + assert.Equal(t, plan.Select.Query, "SELECT `id`, `name`, `created_at` FROM `autogenerated_records` WHERE ") + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) + assert.Equal(t, plan.Insert.Query, "INSERT INTO `autogenerated_records` (`name`) VALUES (?) RETURNING `id`, `created_at`") + assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}, {2}}) + assert.Equal(t, plan.Upsert.Query, "") + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) + assert.Equal(t, plan.Update.Query, "UPDATE `autogenerated_records` SET `name` = ?, `created_at` = ? WHERE `id` = ?") + assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) + assert.Equal(t, plan.Delete.Query, "DELETE FROM `autogenerated_records` WHERE `id` = ?") + assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) }) t.Run("PostgresDialect", func(t *testing.T) { @@ -443,8 +460,25 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) { }) t.Run("SqliteDialect", func(t *testing.T) { - _, err := NewStore[record](SqliteDialect()) - assert.Equal(t, err.Error(), `cannot use type oblast.record for queries: multiple columns are marked as auto-filled (id, created_at), but this SQL dialect only supports at most one per table`) + plan, err := buildPlan(reflect.TypeFor[record](), SqliteDialect(), opts) + if err != nil { + t.Error(err) + } + assert.Equal(t, plan.Select.Query, `SELECT "id", "name", "created_at" FROM "autogenerated_records" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Select.ScanIndexes, [][]int{{0}, {1}, {2}}) + assert.Equal(t, plan.Insert.Query, `INSERT INTO "autogenerated_records" ("name") VALUES (?) RETURNING "id", "created_at"`) + assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}}) + assert.DeepEqual(t, plan.Insert.ScanIndexes, [][]int{{0}, {2}}) + assert.Equal(t, plan.Upsert.Query, "") + assert.DeepEqual(t, plan.Upsert.ArgumentIndexes, nil) + assert.DeepEqual(t, plan.Upsert.ScanIndexes, nil) + assert.Equal(t, plan.Update.Query, `UPDATE "autogenerated_records" SET "name" = ?, "created_at" = ? WHERE "id" = ?`) + assert.DeepEqual(t, plan.Update.ArgumentIndexes, [][]int{{1}, {2}, {0}}) + assert.DeepEqual(t, plan.Update.ScanIndexes, nil) + assert.Equal(t, plan.Delete.Query, `DELETE FROM "autogenerated_records" WHERE "id" = ?`) + assert.DeepEqual(t, plan.Delete.ArgumentIndexes, [][]int{{0}}) + assert.DeepEqual(t, plan.Delete.ScanIndexes, nil) }) } @@ -495,16 +529,6 @@ func TestPlanErrorCases(t *testing.T) { assert.Equal(t, err.Error(), `cannot use type oblast.recordWithUnknownPK for queries: `+ "no field has tag `db:\"record_id\"`, but a field of this name was declared in the primary key") - type recordWithNonintegerAutoKey struct { - CreatedAt time.Time `db:"created_at,auto"` - Name string `db:"name"` - } - _, err = NewStore[recordWithNonintegerAutoKey](SqliteDialect(), - TableNameIs("records"), - ) - assert.Equal(t, err.Error(), `cannot use type oblast.recordWithNonintegerAutoKey for queries: `+ - `column "created_at" is marked as auto-filled, but this SQL dialect only supports auto-filling struct fields with integer types`) - type recordWithWeirdTagOption struct { ID int64 `db:",auto"` Name string `db:",unique"` @@ -80,80 +80,6 @@ func (s preparedStatement) QueryRow(args ...any) *sql.Row { func (s Store[R]) Insert(db Handle, records ...*R) error { // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. // Any expression that does not depend on type R should be factored out into a reusable function. - if s.dialect.UsesLastInsertID() || len(s.plan.Insert.ScanIndexes) == 0 { - return s.insertUsingLastInsertID(db, records) - } else { - return s.insertUsingReturningClause(db, records) - } -} - -func (s Store[R]) insertUsingLastInsertID(db Handle, records []*R) (returnedError error) { - // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. - // Any expression that does not depend on type R should be factored out into a reusable function. - - stmt, err := prepare(db, s.plan.Insert.Query, "Insert", len(records)) - if err != nil { - return err - } - defer func() { - returnedError = newIOError(returnedError, "Stmt.Close", stmt.Close()) - }() - - var ( - argumentIndexes = s.plan.Insert.ArgumentIndexes - argumentSlots = make([]any, len(argumentIndexes)) - scanIndex []int - ) - if len(s.plan.Insert.ScanIndexes) > 0 { - scanIndex = s.plan.Insert.ScanIndexes[0] - } - for idx := range records { - v := reflect.ValueOf(records[idx]).Elem() - err := insertRecordUsingLastInsertID(v, idx, stmt, argumentIndexes, argumentSlots, scanIndex, s.plan) - if err != nil { - return newIOError(err, "Stmt.Close", stmt.Close()) - } - } - - return newIOError(nil, "Stmt.Close", stmt.Close()) -} - -func insertRecordUsingLastInsertID(v reflect.Value, recordIndex int, stmt preparedStatement, argumentIndexes [][]int, argumentSlots []any, scanIndex []int, plan plan) error { - for idx, index := range argumentIndexes { - argumentSlots[idx] = v.FieldByIndex(index).Interface() - } - var scanField reflect.Value - if scanIndex != nil { - scanField = v.FieldByIndex(scanIndex) - if !scanField.IsZero() { - return fmt.Errorf(`refusing to INSERT record with idx = %d that already has non-zero values in its "auto" columns`, recordIndex) - } - } - - result, err := stmt.Exec(argumentSlots...) - if err != nil { - return fmt.Errorf("during Exec() for record with idx = %d: %w", recordIndex, err) - } - if scanIndex != nil { - id, err := result.LastInsertId() - if err != nil { - return fmt.Errorf("during LastInsertId() for record with idx = %d: %w", recordIndex, err) - } - if plan.FillIDWithSetInt { - scanField.SetInt(id) - } else if plan.FillIDWithSetUint { - if id < 0 { - return fmt.Errorf("LastInsertId() = %d for record with idx = %d cannot be converted to uint", id, recordIndex) - } - scanField.SetUint(uint64(id)) - } - } - return nil -} - -func (s Store[R]) insertUsingReturningClause(db Handle, records []*R) (returnedError error) { - // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. - // Any expression that does not depend on type R should be factored out into a reusable function. stmt, err := prepare(db, s.plan.Insert.Query, "Insert", len(records)) if err != nil { @@ -169,7 +95,7 @@ func (s Store[R]) insertUsingReturningClause(db Handle, records []*R) (returnedE for idx := range records { v := reflect.ValueOf(records[idx]).Elem() - err := insertRecordUsingReturningClause(v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots) + err := insertRecord(v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots) if err != nil { return newIOError(err, "Stmt.Close", stmt.Close()) } @@ -178,7 +104,7 @@ func (s Store[R]) insertUsingReturningClause(db Handle, records []*R) (returnedE return newIOError(nil, "Stmt.Close", stmt.Close()) } -func insertRecordUsingReturningClause(v reflect.Value, recordIndex int, stmt preparedStatement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error { +func insertRecord(v reflect.Value, recordIndex int, stmt preparedStatement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error { for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } @@ -189,9 +115,14 @@ func insertRecordUsingReturningClause(v reflect.Value, recordIndex int, stmt pre } scanSlots[idx] = f.Addr().Interface() } - err := stmt.QueryRow(argumentSlots...).Scan(scanSlots...) + var err error + if len(scanSlots) == 0 { + _, err = stmt.Exec(argumentSlots...) + } else { + err = stmt.QueryRow(argumentSlots...).Scan(scanSlots...) + } if err != nil { - return fmt.Errorf("during QueryRow() for record with idx = %d: %w", recordIndex, err) + return fmt.Errorf("while inserting record with idx = %d: %w", recordIndex, err) } return nil } @@ -233,7 +164,7 @@ func updateRecord(v reflect.Value, recordIndex int, stmt preparedStatement, argu } result, err := stmt.Exec(argumentSlots...) if err != nil { - return 0, fmt.Errorf("during Exec() for record with idx = %d: %w", recordIndex, err) + return 0, fmt.Errorf("while updating record with idx = %d: %w", recordIndex, err) } rowsAffected, err := result.RowsAffected() if err != nil { @@ -276,7 +207,7 @@ func deleteRecord(v reflect.Value, recordIndex int, stmt preparedStatement, argu } _, err := stmt.Exec(argumentSlots...) if err != nil { - return fmt.Errorf("during Exec() for record with idx = %d: %w", recordIndex, err) + return fmt.Errorf("while deleting record with idx = %d: %w", recordIndex, err) } return nil } diff --git a/query_test.go b/query_test.go index ae60db9..29cb015 100644 --- a/query_test.go +++ b/query_test.go @@ -14,7 +14,7 @@ import ( "go.xyrillian.de/oblast/internal/must" ) -func TestInsertBasicUsingLastInsertId(t *testing.T) { +func TestInsertBasic(t *testing.T) { md := mock.NewDriver() db := sql.OpenDB(md) @@ -34,39 +34,7 @@ func TestInsertBasicUsingLastInsertId(t *testing.T) { records := make([]*basicRecord, batchSize) for idx := range batchSize { records[idx] = &basicRecord{Name: "new"} - md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`). - ExpectExecWithArgs("new"). - AndReturnLastInsertId(int64(42 + idx)). - AndReturnRowsAffected(1) - } - must.Succeed(t, store.Insert(db, records...)) - for idx, r := range records { - assert.Equal(t, r.ID, int64(42+idx)) - } - }) - } -} - -func TestInsertBasicUsingReturningClause(t *testing.T) { - md := mock.NewDriver() - db := sql.OpenDB(md) - - type basicRecord struct { - ID int64 `db:"id,auto"` - Name string `db:"name"` - } - store := oblast.MustNewStore[basicRecord]( - oblast.PostgresDialect(), - oblast.TableNameIs("basic_records"), - oblast.PrimaryKeyIs("id"), - ) - - for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} { - t.Run("N="+strconv.Itoa(batchSize), func(t *testing.T) { - records := make([]*basicRecord, batchSize) - for idx := range batchSize { - records[idx] = &basicRecord{Name: "new"} - md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`). + md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`). ExpectQueryWithArgs("new"). AndReturnColumns("id"). WithRow(int64(42 + idx)) @@ -185,9 +153,9 @@ func TestWriteQueriesFailDuringPrepare(t *testing.T) { } err := store.Insert(db, recordsForInsert...) - baseError := `unexpected query: INSERT INTO "basic_records" ("name") VALUES (?)` + baseError := `unexpected query: INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"` if batchSize < oblast.PrepareThreshold { - assert.ErrEqual(t, err, "during Exec() for record with idx = 0: "+baseError) + assert.ErrEqual(t, err, "while inserting record with idx = 0: "+baseError) } else { assert.ErrEqual(t, err, "during Prepare(): "+baseError) } @@ -195,7 +163,7 @@ func TestWriteQueriesFailDuringPrepare(t *testing.T) { err = store.Update(db, records...) baseError = `unexpected query: UPDATE "basic_records" SET "name" = ? WHERE "id" = ?` if batchSize < oblast.PrepareThreshold { - assert.ErrEqual(t, err, "during Exec() for record with idx = 0: "+baseError) + assert.ErrEqual(t, err, "while updating record with idx = 0: "+baseError) } else { assert.ErrEqual(t, err, "during Prepare(): "+baseError) } @@ -203,28 +171,7 @@ func TestWriteQueriesFailDuringPrepare(t *testing.T) { err = store.Delete(db, records...) baseError = `unexpected query: DELETE FROM "basic_records" WHERE "id" = ?` if batchSize < oblast.PrepareThreshold { - assert.ErrEqual(t, err, "during Exec() for record with idx = 0: "+baseError) - } else { - assert.ErrEqual(t, err, "during Prepare(): "+baseError) - } - } - - store = oblast.MustNewStore[basicRecord]( - oblast.PostgresDialect(), // for test coverage of insertUsingReturningClause() - oblast.TableNameIs("basic_records"), - oblast.PrimaryKeyIs("id"), - ) - - for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} { - recordsForInsert := make([]*basicRecord, batchSize) - for idx := range batchSize { - recordsForInsert[idx] = &basicRecord{Name: "foo"} - } - - err := store.Insert(db, recordsForInsert...) - baseError := `unexpected query: INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"` - if batchSize < oblast.PrepareThreshold { - assert.ErrEqual(t, err, "during QueryRow() for record with idx = 0: "+baseError) + assert.ErrEqual(t, err, "while deleting record with idx = 0: "+baseError) } else { assert.ErrEqual(t, err, "during Prepare(): "+baseError) } @@ -254,83 +201,29 @@ func TestUpdateFailsOnMissingRecord(t *testing.T) { assert.Equal(t, hasCorrectType, true) } -func TestInsertWithUnsignedIdField(t *testing.T) { +func TestInsertFailsOnFilledAutoField(t *testing.T) { md := mock.NewDriver() db := sql.OpenDB(md) type basicRecord struct { - ID uint64 `db:"id,auto"` // not int64! + ID int64 `db:"id,auto"` Name string `db:"name"` } + store := oblast.MustNewStore[basicRecord]( + oblast.SqliteDialect(), + oblast.TableNameIs("basic_records"), + oblast.PrimaryKeyIs("id"), + ) - t.Run("using LastInsertID", func(t *testing.T) { - store := oblast.MustNewStore[basicRecord]( - oblast.SqliteDialect(), - oblast.TableNameIs("basic_records"), - oblast.PrimaryKeyIs("id"), - ) - - // success case - md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`). - ExpectExecWithArgs("first"). - AndReturnLastInsertId(42). - AndReturnRowsAffected(1) - record := basicRecord{Name: "first"} - must.Succeed(t, store.Insert(db, &record)) - assert.Equal(t, record, basicRecord{ID: 42, Name: "first"}) - - // error case: negative ID cannot be cast to uint64 - md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`). - ExpectExecWithArgs("second"). - AndReturnLastInsertId(-42). - AndReturnRowsAffected(1) - err := store.Insert(db, &basicRecord{Name: "second"}) - assert.ErrEqual(t, err, "LastInsertId() = -42 for record with idx = 0 cannot be converted to uint") - - // error case: cannot Insert() a record that already has its ID field filled - md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`). - ExpectExecWithArgs("third"). - AndReturnLastInsertId(42). - AndReturnRowsAffected(1) - err = store.Insert(db, &basicRecord{ID: 23, Name: "third"}) - assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0 that already has non-zero values in its "auto" columns`) - }) - - t.Run("using RETURNING clause", func(t *testing.T) { - store := oblast.MustNewStore[basicRecord]( - oblast.PostgresDialect(), - oblast.TableNameIs("basic_records"), - oblast.PrimaryKeyIs("id"), - ) - - // success case - md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`). - ExpectQueryWithArgs("first"). - AndReturnColumns("id"). - WithRow(42) - record := basicRecord{Name: "first"} - must.Succeed(t, store.Insert(db, &record)) - assert.Equal(t, record, basicRecord{ID: 42, Name: "first"}) - - // error case: negative ID cannot be cast to uint64 - md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`). - ExpectQueryWithArgs("second"). - AndReturnColumns("id"). - WithRow(-42) - err := store.Insert(db, &basicRecord{Name: "second"}) - assert.ErrEqual(t, err, `during QueryRow() for record with idx = 0: sql: Scan error on column index 0, name "id": converting driver.Value type int ("-42") to a uint64: invalid syntax`) - - // error case: cannot Insert() a record that already has its ID field filled - md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`). - ExpectQueryWithArgs("third"). - AndReturnColumns("id"). - WithRow(42) - err = store.Insert(db, &basicRecord{ID: 23, Name: "third"}) - assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0 that already has non-zero values in its "auto" columns`) - }) + md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?) RETURNING "id"`). + ExpectQueryWithArgs("existing"). + AndReturnColumns("id"). + WithRow(42) + err := store.Insert(db, &basicRecord{ID: 23, Name: "third"}) + assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0 that already has non-zero values in its "auto" columns`) } -func TestInsertWithoutAutoColumns(t *testing.T) { +func TestInsertWithNoAutoColumns(t *testing.T) { md := mock.NewDriver() db := sql.OpenDB(md) @@ -338,42 +231,14 @@ func TestInsertWithoutAutoColumns(t *testing.T) { FooID int64 `db:"foo_id"` BarID int64 `db:"bar_id"` } + store := oblast.MustNewStore[relation]( + oblast.SqliteDialect(), + oblast.TableNameIs("foo_bar_relations"), + oblast.PrimaryKeyIs("foo_id", "bar_id"), + ) - // Even in dialects using RETURNING clause, this uses Exec() because there is nothing to return. - // Therefore, the test behavior with both dialects is identical except for the different placeholder syntax in the query. - runTest := func(store oblast.Store[relation], query string) { - md.ForQuery(query). - ExpectExecWithArgs(1, 2). - AndReturnRowsAffected(1) - md.ForQuery(query). - ExpectExecWithArgs(1, 3). - AndReturnRowsAffected(1) - relations := []*relation{ - {FooID: 1, BarID: 2}, - {FooID: 1, BarID: 3}, - } - must.Succeed(t, store.Insert(db, relations...)) - assert.SliceDeepEqual(t, relations, - &relation{FooID: 1, BarID: 2}, - &relation{FooID: 1, BarID: 3}, - ) - } - - t.Run("in dialect using LastInsertID", func(t *testing.T) { - store := oblast.MustNewStore[relation]( - oblast.SqliteDialect(), - oblast.TableNameIs("foo_bar_relations"), - oblast.PrimaryKeyIs("foo_id", "bar_id"), - ) - runTest(store, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`) - }) - - t.Run("in dialect using RETURNING clause", func(t *testing.T) { - store := oblast.MustNewStore[relation]( - oblast.PostgresDialect(), - oblast.TableNameIs("foo_bar_relations"), - oblast.PrimaryKeyIs("foo_id", "bar_id"), - ) - runTest(store, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2)`) - }) + md.ForQuery(`INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`). + ExpectExecWithArgs(23, 42). + AndReturnRowsAffected(1) + must.Succeed(t, store.Insert(db, &relation{23, 42})) } |
