diff options
| author | Stefan Majewsky <majewsky@gmx.net> | 2026-04-18 15:44:46 +0200 |
|---|---|---|
| committer | Stefan Majewsky <majewsky@gmx.net> | 2026-04-18 15:45:42 +0200 |
| commit | 01d2d52fd7dfb64c41f7c94808fe01665ffcb881 (patch) | |
| tree | cd17477850ee1d34d11b681ef3f10a24b49c04f6 | |
| parent | cfdb06f5ba144aad5d2ebe31ec8bb64e017f4023 (diff) | |
| download | go-oblast-01d2d52fd7dfb64c41f7c94808fe01665ffcb881.tar.gz | |
more test coverage, forbid non-zero auto columns during Insert()
| -rw-r--r-- | errors.go | 29 | ||||
| -rw-r--r-- | errors_test.go | 40 | ||||
| -rw-r--r-- | query.go | 18 | ||||
| -rw-r--r-- | query_test.go | 201 |
4 files changed, 264 insertions, 24 deletions
@@ -27,9 +27,12 @@ func (e MissingRecordError[R]) Error() string { return "could not UPDATE record that does not exist in the database: " + strings.Join(keyDescs, ", ") } -// An error type that optionally contains either one of the following or both: -// - a core error from an IO operation (e.g. a database read) +// ioError is an error type that contains: +// - (optionally) a main error from an IO operation (e.g. a database read) // - an auxiliary error from closing or otherwise cleaning up the respective IO handle +// +// This is only used when there is a cleanup error. +// Otherwise, the main error will be returned without being wrapped in this type. type ioError struct { MainError error CleanupError error @@ -37,32 +40,26 @@ type ioError struct { } func newIOError(err error, cleanupOperation string, cleanupErr error) error { - if err == nil && cleanupErr == nil { - return nil + if cleanupErr == nil { + return err } return ioError{err, cleanupErr, cleanupOperation} } // Error implements the builtin/error interface. func (e ioError) Error() string { - switch { - case e.CleanupError == nil: - return e.MainError.Error() - case e.MainError == nil: + if e.MainError == nil { return fmt.Sprintf("during %s(): %s", e.CleanupOperation, e.CleanupError.Error()) - default: + } else { return fmt.Sprintf("%s (additional error during %s(): %s)", e.MainError.Error(), e.CleanupOperation, e.CleanupError.Error()) } } // Unwrap implements the interface implied by the documentation of package errors. func (e ioError) Unwrap() []error { - result := make([]error, 0, 2) - if e.MainError != nil { - result = append(result, e.MainError) + if e.MainError == nil { + return []error{e.CleanupError} + } else { + return []error{e.MainError, e.CleanupError} } - if e.CleanupError != nil { - result = append(result, e.CleanupError) - } - return result } diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..c39cf67 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net> +// SPDX-License-Identifier: Apache-2.0 + +package oblast + +import ( + "errors" + "testing" + + "go.xyrillian.de/oblast/internal/assert" +) + +type fooError struct{} +type barError struct{} +type bazError struct{} + +func (fooError) Error() string { return "foo" } +func (barError) Error() string { return "bar" } +func (bazError) Error() string { return "baz" } + +func TestIOError(t *testing.T) { + err := newIOError(nil, "File.Close", nil) + assert.Equal(t, err == nil, true) + + err = newIOError(fooError{}, "File.Close", nil) + assert.ErrEqual(t, err, "foo") + assert.DeepEqual(t, err, error(fooError{})) // check for no wrapping in type ioError without cleanup error + + err = newIOError(nil, "File.Close", barError{}) + assert.ErrEqual(t, err, "during File.Close(): bar") + assert.Equal(t, errors.Is(err, fooError{}), false) + assert.Equal(t, errors.Is(err, barError{}), true) + assert.Equal(t, errors.Is(err, bazError{}), false) + + err = newIOError(fooError{}, "File.Close", barError{}) + assert.ErrEqual(t, err, "foo (additional error during File.Close(): bar)") + assert.Equal(t, errors.Is(err, fooError{}), true) + assert.Equal(t, errors.Is(err, barError{}), true) + assert.Equal(t, errors.Is(err, bazError{}), false) +} @@ -75,6 +75,9 @@ func (s preparedStatement) QueryRow(args ...any) *sql.Row { // On success, returns the original set of records, updated thusly. // // Returns an error if [NewStore] was called without the [TableNameIs] option, which is required to generate a query for this method. +// +// Returns an error if any of the `records` has a non-zero value in any column marked as `db:",auto"`. +// Records that already exist in the database should be handled with [Store.Update] instead. func (s Store[R]) Insert(db Handle, records ...R) ([]R, error) { // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. // Any expression that does not depend on type R should be factored out into a reusable function. @@ -117,6 +120,11 @@ func insertRecordUsingLastInsertID(v reflect.Value, recordIndex int, stmt prepar for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } + scanField := v.FieldByIndex(scanIndex) + if !scanField.IsZero() { + return fmt.Errorf(`refusing to INSERT record with idx = %d that already has non-zero values in its "auto" columns`, recordIndex) + } + result, err := stmt.Exec(argumentSlots...) if err != nil { return fmt.Errorf("during Exec() for record with idx = %d: %w", recordIndex, err) @@ -126,12 +134,12 @@ func insertRecordUsingLastInsertID(v reflect.Value, recordIndex int, stmt prepar return fmt.Errorf("during LastInsertId() for record with idx = %d: %w", recordIndex, err) } if plan.FillIDWithSetInt { - v.FieldByIndex(scanIndex).SetInt(id) + scanField.SetInt(id) } else if plan.FillIDWithSetUint { if id < 0 { return fmt.Errorf("LastInsertId() = %d for record with idx = %d cannot be converted to uint", id, recordIndex) } - v.FieldByIndex(scanIndex).SetUint(uint64(id)) + scanField.SetUint(uint64(id)) } return nil } @@ -168,7 +176,11 @@ func insertRecordUsingReturningClause(v reflect.Value, recordIndex int, stmt pre argumentSlots[idx] = v.FieldByIndex(index).Interface() } for idx, index := range scanIndexes { - scanSlots[idx] = v.FieldByIndex(index).Addr().Interface() + f := v.FieldByIndex(index) + if !f.IsZero() { + return fmt.Errorf(`refusing to INSERT record with idx = %d that already has non-zero values in its "auto" columns`, recordIndex) + } + scanSlots[idx] = f.Addr().Interface() } err := stmt.QueryRow(argumentSlots...).Scan(scanSlots...) if err != nil { diff --git a/query_test.go b/query_test.go index 388183e..000c385 100644 --- a/query_test.go +++ b/query_test.go @@ -87,7 +87,7 @@ func TestUpdateBasic(t *testing.T) { Name string `db:"name"` } store := oblast.MustNewStore[basicRecord]( - oblast.PostgresDialect(), + oblast.SqliteDialect(), oblast.TableNameIs("basic_records"), oblast.PrimaryKeyIs("id"), ) @@ -98,7 +98,7 @@ func TestUpdateBasic(t *testing.T) { for idx := range batchSize { r := basicRecord{ID: int64(42 + idx), Name: "updated"} records[idx] = r - md.ForQuery(`UPDATE "basic_records" SET "name" = $1 WHERE "id" = $2`). + md.ForQuery(`UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`). ExpectExecWithArgs(r.Name, r.ID). AndReturnRowsAffected(1) } @@ -116,7 +116,7 @@ func TestDeleteBasic(t *testing.T) { Name string `db:"name"` } store := oblast.MustNewStore[basicRecord]( - oblast.PostgresDialect(), + oblast.SqliteDialect(), oblast.TableNameIs("basic_records"), oblast.PrimaryKeyIs("id"), ) @@ -127,7 +127,7 @@ func TestDeleteBasic(t *testing.T) { for idx := range batchSize { r := basicRecord{ID: int64(42 + idx), Name: "removed"} records[idx] = r - md.ForQuery(`DELETE FROM "basic_records" WHERE "id" = $1`). + md.ForQuery(`DELETE FROM "basic_records" WHERE "id" = ?`). ExpectExecWithArgs(r.ID). AndReturnRowsAffected(1) } @@ -136,4 +136,195 @@ func TestDeleteBasic(t *testing.T) { } } -// TODO: more test coverage for query.go +func TestWriteQueriesNotPossible(t *testing.T) { + md := mock.NewDriver() + db := sql.OpenDB(md) + + type basicRecord struct { + ID int64 `db:"id,auto"` + Name string `db:"name"` + } + store := oblast.MustNewStore[basicRecord]( + oblast.SqliteDialect(), + // no TableNameIs() or PrimaryKeyIs() given + ) + + r := basicRecord{Name: "foo"} + _, err := store.Insert(db, r) + assert.ErrEqual(t, err, "cannot execute Insert() because query could not be autogenerated") + + r.ID = 42 + err = store.Update(db, r) + assert.ErrEqual(t, err, "cannot execute Update() because query could not be autogenerated") + + err = store.Delete(db, r) + assert.ErrEqual(t, err, "cannot execute Delete() because query could not be autogenerated") +} + +func TestWriteQueriesFailDuringPrepare(t *testing.T) { + md := mock.NewDriver() + db := sql.OpenDB(md) + + type basicRecord struct { + ID int64 `db:"id,auto"` + Name string `db:"name"` + } + store := oblast.MustNewStore[basicRecord]( + oblast.SqliteDialect(), + oblast.TableNameIs("basic_records"), + oblast.PrimaryKeyIs("id"), + ) + + for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} { + records := make([]basicRecord, batchSize) + for idx := range batchSize { + records[idx] = basicRecord{Name: "foo"} + } + + _, err := store.Insert(db, records...) + baseError := `unexpected query: INSERT INTO "basic_records" ("name") VALUES (?)` + if batchSize < oblast.PrepareThreshold { + assert.ErrEqual(t, err, "during Exec() for record with idx = 0: "+baseError) + } else { + assert.ErrEqual(t, err, "during Prepare(): "+baseError) + } + + for idx := range batchSize { + records[idx].ID = int64(42 + idx) + } + + err = store.Update(db, records...) + baseError = `unexpected query: UPDATE "basic_records" SET "name" = ? WHERE "id" = ?` + if batchSize < oblast.PrepareThreshold { + assert.ErrEqual(t, err, "during Exec() for record with idx = 0: "+baseError) + } else { + assert.ErrEqual(t, err, "during Prepare(): "+baseError) + } + + err = store.Delete(db, records...) + baseError = `unexpected query: DELETE FROM "basic_records" WHERE "id" = ?` + if batchSize < oblast.PrepareThreshold { + assert.ErrEqual(t, err, "during Exec() for record with idx = 0: "+baseError) + } else { + assert.ErrEqual(t, err, "during Prepare(): "+baseError) + } + } + + store = oblast.MustNewStore[basicRecord]( + oblast.PostgresDialect(), // for test coverage of insertUsingReturningClause() + oblast.TableNameIs("basic_records"), + oblast.PrimaryKeyIs("id"), + ) + + for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} { + records := make([]basicRecord, batchSize) + for idx := range batchSize { + records[idx] = basicRecord{Name: "foo"} + } + + _, err := store.Insert(db, records...) + baseError := `unexpected query: INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"` + if batchSize < oblast.PrepareThreshold { + assert.ErrEqual(t, err, "during QueryRow() for record with idx = 0: "+baseError) + } else { + assert.ErrEqual(t, err, "during Prepare(): "+baseError) + } + } +} + +func TestUpdateFailsOnMissingRecord(t *testing.T) { + md := mock.NewDriver() + db := sql.OpenDB(md) + + type basicRecord struct { + ID int64 `db:"id,auto"` + Name string `db:"name"` + } + store := oblast.MustNewStore[basicRecord]( + oblast.SqliteDialect(), + oblast.TableNameIs("basic_records"), + oblast.PrimaryKeyIs("id"), + ) + + md.ForQuery(`UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`). + ExpectExecWithArgs("changed", 42). + AndReturnRowsAffected(0) + err := store.Update(db, basicRecord{ID: 42, Name: "changed"}) + assert.ErrEqual(t, err, "could not UPDATE record that does not exist in the database: id = 42") + _, hasCorrectType := err.(oblast.MissingRecordError[basicRecord]) //nolint:errorlint // we explicitly do not want a wrapped error + assert.Equal(t, hasCorrectType, true) +} + +func TestInsertWithUnsignedIdField(t *testing.T) { + md := mock.NewDriver() + db := sql.OpenDB(md) + + type basicRecord struct { + ID uint64 `db:"id,auto"` // not int64! + Name string `db:"name"` + } + + t.Run("using LastInsertID", func(t *testing.T) { + store := oblast.MustNewStore[basicRecord]( + oblast.SqliteDialect(), + oblast.TableNameIs("basic_records"), + oblast.PrimaryKeyIs("id"), + ) + + // success case + md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`). + ExpectExecWithArgs("first"). + AndReturnLastInsertId(42). + AndReturnRowsAffected(1) + records := must.Return(store.Insert(db, basicRecord{Name: "first"}))(t) + assert.SliceEqual(t, records, basicRecord{ID: 42, Name: "first"}) + + // error case: negative ID cannot be cast to uint64 + md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`). + ExpectExecWithArgs("second"). + AndReturnLastInsertId(-42). + AndReturnRowsAffected(1) + _, err := store.Insert(db, basicRecord{Name: "second"}) + assert.ErrEqual(t, err, "LastInsertId() = -42 for record with idx = 0 cannot be converted to uint") + + // error case: cannot Insert() a record that already has its ID field filled + md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`). + ExpectExecWithArgs("third"). + AndReturnLastInsertId(42). + AndReturnRowsAffected(1) + _, err = store.Insert(db, basicRecord{ID: 23, Name: "third"}) + assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0 that already has non-zero values in its "auto" columns`) + }) + + t.Run("using RETURNING clause", func(t *testing.T) { + store := oblast.MustNewStore[basicRecord]( + oblast.PostgresDialect(), + oblast.TableNameIs("basic_records"), + oblast.PrimaryKeyIs("id"), + ) + + // success case + md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`). + ExpectQueryWithArgs("first"). + AndReturnColumns("id"). + WithRow(42) + records := must.Return(store.Insert(db, basicRecord{Name: "first"}))(t) + assert.SliceEqual(t, records, basicRecord{ID: 42, Name: "first"}) + + // error case: negative ID cannot be cast to uint64 + md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`). + ExpectQueryWithArgs("second"). + AndReturnColumns("id"). + WithRow(-42) + _, err := store.Insert(db, basicRecord{Name: "second"}) + assert.ErrEqual(t, err, `during QueryRow() for record with idx = 0: sql: Scan error on column index 0, name "id": converting driver.Value type int ("-42") to a uint64: invalid syntax`) + + // error case: cannot Insert() a record that already has its ID field filled + md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`). + ExpectQueryWithArgs("third"). + AndReturnColumns("id"). + WithRow(42) + _, err = store.Insert(db, basicRecord{ID: 23, Name: "third"}) + assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0 that already has non-zero values in its "auto" columns`) + }) +} |
