aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Majewsky <majewsky@gmx.net>2026-04-18 15:44:46 +0200
committerStefan Majewsky <majewsky@gmx.net>2026-04-18 15:45:42 +0200
commit01d2d52fd7dfb64c41f7c94808fe01665ffcb881 (patch)
treecd17477850ee1d34d11b681ef3f10a24b49c04f6
parentcfdb06f5ba144aad5d2ebe31ec8bb64e017f4023 (diff)
downloadgo-oblast-01d2d52fd7dfb64c41f7c94808fe01665ffcb881.tar.gz
more test coverage, forbid non-zero auto columns during Insert()
-rw-r--r--errors.go29
-rw-r--r--errors_test.go40
-rw-r--r--query.go18
-rw-r--r--query_test.go201
4 files changed, 264 insertions, 24 deletions
diff --git a/errors.go b/errors.go
index 1e81060..4002f58 100644
--- a/errors.go
+++ b/errors.go
@@ -27,9 +27,12 @@ func (e MissingRecordError[R]) Error() string {
return "could not UPDATE record that does not exist in the database: " + strings.Join(keyDescs, ", ")
}
-// An error type that optionally contains either one of the following or both:
-// - a core error from an IO operation (e.g. a database read)
+// ioError is an error type that contains:
+// - (optionally) a main error from an IO operation (e.g. a database read)
// - an auxiliary error from closing or otherwise cleaning up the respective IO handle
+//
+// This is only used when there is a cleanup error.
+// Otherwise, the main error will be returned without being wrapped in this type.
type ioError struct {
MainError error
CleanupError error
@@ -37,32 +40,26 @@ type ioError struct {
}
func newIOError(err error, cleanupOperation string, cleanupErr error) error {
- if err == nil && cleanupErr == nil {
- return nil
+ if cleanupErr == nil {
+ return err
}
return ioError{err, cleanupErr, cleanupOperation}
}
// Error implements the builtin/error interface.
func (e ioError) Error() string {
- switch {
- case e.CleanupError == nil:
- return e.MainError.Error()
- case e.MainError == nil:
+ if e.MainError == nil {
return fmt.Sprintf("during %s(): %s", e.CleanupOperation, e.CleanupError.Error())
- default:
+ } else {
return fmt.Sprintf("%s (additional error during %s(): %s)", e.MainError.Error(), e.CleanupOperation, e.CleanupError.Error())
}
}
// Unwrap implements the interface implied by the documentation of package errors.
func (e ioError) Unwrap() []error {
- result := make([]error, 0, 2)
- if e.MainError != nil {
- result = append(result, e.MainError)
+ if e.MainError == nil {
+ return []error{e.CleanupError}
+ } else {
+ return []error{e.MainError, e.CleanupError}
}
- if e.CleanupError != nil {
- result = append(result, e.CleanupError)
- }
- return result
}
diff --git a/errors_test.go b/errors_test.go
new file mode 100644
index 0000000..c39cf67
--- /dev/null
+++ b/errors_test.go
@@ -0,0 +1,40 @@
+// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net>
+// SPDX-License-Identifier: Apache-2.0
+
+package oblast
+
+import (
+ "errors"
+ "testing"
+
+ "go.xyrillian.de/oblast/internal/assert"
+)
+
+type fooError struct{}
+type barError struct{}
+type bazError struct{}
+
+func (fooError) Error() string { return "foo" }
+func (barError) Error() string { return "bar" }
+func (bazError) Error() string { return "baz" }
+
+func TestIOError(t *testing.T) {
+ err := newIOError(nil, "File.Close", nil)
+ assert.Equal(t, err == nil, true)
+
+ err = newIOError(fooError{}, "File.Close", nil)
+ assert.ErrEqual(t, err, "foo")
+ assert.DeepEqual(t, err, error(fooError{})) // check for no wrapping in type ioError without cleanup error
+
+ err = newIOError(nil, "File.Close", barError{})
+ assert.ErrEqual(t, err, "during File.Close(): bar")
+ assert.Equal(t, errors.Is(err, fooError{}), false)
+ assert.Equal(t, errors.Is(err, barError{}), true)
+ assert.Equal(t, errors.Is(err, bazError{}), false)
+
+ err = newIOError(fooError{}, "File.Close", barError{})
+ assert.ErrEqual(t, err, "foo (additional error during File.Close(): bar)")
+ assert.Equal(t, errors.Is(err, fooError{}), true)
+ assert.Equal(t, errors.Is(err, barError{}), true)
+ assert.Equal(t, errors.Is(err, bazError{}), false)
+}
diff --git a/query.go b/query.go
index 2ed113a..04ba647 100644
--- a/query.go
+++ b/query.go
@@ -75,6 +75,9 @@ func (s preparedStatement) QueryRow(args ...any) *sql.Row {
// On success, returns the original set of records, updated thusly.
//
// Returns an error if [NewStore] was called without the [TableNameIs] option, which is required to generate a query for this method.
+//
+// Returns an error if any of the `records` has a non-zero value in any column marked as `db:",auto"`.
+// Records that already exist in the database should be handled with [Store.Update] instead.
func (s Store[R]) Insert(db Handle, records ...R) ([]R, error) {
// NOTE: This function body should be as short as possible to reduce the binary size after monomorphization.
// Any expression that does not depend on type R should be factored out into a reusable function.
@@ -117,6 +120,11 @@ func insertRecordUsingLastInsertID(v reflect.Value, recordIndex int, stmt prepar
for idx, index := range argumentIndexes {
argumentSlots[idx] = v.FieldByIndex(index).Interface()
}
+ scanField := v.FieldByIndex(scanIndex)
+ if !scanField.IsZero() {
+ return fmt.Errorf(`refusing to INSERT record with idx = %d that already has non-zero values in its "auto" columns`, recordIndex)
+ }
+
result, err := stmt.Exec(argumentSlots...)
if err != nil {
return fmt.Errorf("during Exec() for record with idx = %d: %w", recordIndex, err)
@@ -126,12 +134,12 @@ func insertRecordUsingLastInsertID(v reflect.Value, recordIndex int, stmt prepar
return fmt.Errorf("during LastInsertId() for record with idx = %d: %w", recordIndex, err)
}
if plan.FillIDWithSetInt {
- v.FieldByIndex(scanIndex).SetInt(id)
+ scanField.SetInt(id)
} else if plan.FillIDWithSetUint {
if id < 0 {
return fmt.Errorf("LastInsertId() = %d for record with idx = %d cannot be converted to uint", id, recordIndex)
}
- v.FieldByIndex(scanIndex).SetUint(uint64(id))
+ scanField.SetUint(uint64(id))
}
return nil
}
@@ -168,7 +176,11 @@ func insertRecordUsingReturningClause(v reflect.Value, recordIndex int, stmt pre
argumentSlots[idx] = v.FieldByIndex(index).Interface()
}
for idx, index := range scanIndexes {
- scanSlots[idx] = v.FieldByIndex(index).Addr().Interface()
+ f := v.FieldByIndex(index)
+ if !f.IsZero() {
+ return fmt.Errorf(`refusing to INSERT record with idx = %d that already has non-zero values in its "auto" columns`, recordIndex)
+ }
+ scanSlots[idx] = f.Addr().Interface()
}
err := stmt.QueryRow(argumentSlots...).Scan(scanSlots...)
if err != nil {
diff --git a/query_test.go b/query_test.go
index 388183e..000c385 100644
--- a/query_test.go
+++ b/query_test.go
@@ -87,7 +87,7 @@ func TestUpdateBasic(t *testing.T) {
Name string `db:"name"`
}
store := oblast.MustNewStore[basicRecord](
- oblast.PostgresDialect(),
+ oblast.SqliteDialect(),
oblast.TableNameIs("basic_records"),
oblast.PrimaryKeyIs("id"),
)
@@ -98,7 +98,7 @@ func TestUpdateBasic(t *testing.T) {
for idx := range batchSize {
r := basicRecord{ID: int64(42 + idx), Name: "updated"}
records[idx] = r
- md.ForQuery(`UPDATE "basic_records" SET "name" = $1 WHERE "id" = $2`).
+ md.ForQuery(`UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`).
ExpectExecWithArgs(r.Name, r.ID).
AndReturnRowsAffected(1)
}
@@ -116,7 +116,7 @@ func TestDeleteBasic(t *testing.T) {
Name string `db:"name"`
}
store := oblast.MustNewStore[basicRecord](
- oblast.PostgresDialect(),
+ oblast.SqliteDialect(),
oblast.TableNameIs("basic_records"),
oblast.PrimaryKeyIs("id"),
)
@@ -127,7 +127,7 @@ func TestDeleteBasic(t *testing.T) {
for idx := range batchSize {
r := basicRecord{ID: int64(42 + idx), Name: "removed"}
records[idx] = r
- md.ForQuery(`DELETE FROM "basic_records" WHERE "id" = $1`).
+ md.ForQuery(`DELETE FROM "basic_records" WHERE "id" = ?`).
ExpectExecWithArgs(r.ID).
AndReturnRowsAffected(1)
}
@@ -136,4 +136,195 @@ func TestDeleteBasic(t *testing.T) {
}
}
-// TODO: more test coverage for query.go
+func TestWriteQueriesNotPossible(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID int64 `db:"id,auto"`
+ Name string `db:"name"`
+ }
+ store := oblast.MustNewStore[basicRecord](
+ oblast.SqliteDialect(),
+ // no TableNameIs() or PrimaryKeyIs() given
+ )
+
+ r := basicRecord{Name: "foo"}
+ _, err := store.Insert(db, r)
+ assert.ErrEqual(t, err, "cannot execute Insert() because query could not be autogenerated")
+
+ r.ID = 42
+ err = store.Update(db, r)
+ assert.ErrEqual(t, err, "cannot execute Update() because query could not be autogenerated")
+
+ err = store.Delete(db, r)
+ assert.ErrEqual(t, err, "cannot execute Delete() because query could not be autogenerated")
+}
+
+func TestWriteQueriesFailDuringPrepare(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID int64 `db:"id,auto"`
+ Name string `db:"name"`
+ }
+ store := oblast.MustNewStore[basicRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} {
+ records := make([]basicRecord, batchSize)
+ for idx := range batchSize {
+ records[idx] = basicRecord{Name: "foo"}
+ }
+
+ _, err := store.Insert(db, records...)
+ baseError := `unexpected query: INSERT INTO "basic_records" ("name") VALUES (?)`
+ if batchSize < oblast.PrepareThreshold {
+ assert.ErrEqual(t, err, "during Exec() for record with idx = 0: "+baseError)
+ } else {
+ assert.ErrEqual(t, err, "during Prepare(): "+baseError)
+ }
+
+ for idx := range batchSize {
+ records[idx].ID = int64(42 + idx)
+ }
+
+ err = store.Update(db, records...)
+ baseError = `unexpected query: UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`
+ if batchSize < oblast.PrepareThreshold {
+ assert.ErrEqual(t, err, "during Exec() for record with idx = 0: "+baseError)
+ } else {
+ assert.ErrEqual(t, err, "during Prepare(): "+baseError)
+ }
+
+ err = store.Delete(db, records...)
+ baseError = `unexpected query: DELETE FROM "basic_records" WHERE "id" = ?`
+ if batchSize < oblast.PrepareThreshold {
+ assert.ErrEqual(t, err, "during Exec() for record with idx = 0: "+baseError)
+ } else {
+ assert.ErrEqual(t, err, "during Prepare(): "+baseError)
+ }
+ }
+
+ store = oblast.MustNewStore[basicRecord](
+ oblast.PostgresDialect(), // for test coverage of insertUsingReturningClause()
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} {
+ records := make([]basicRecord, batchSize)
+ for idx := range batchSize {
+ records[idx] = basicRecord{Name: "foo"}
+ }
+
+ _, err := store.Insert(db, records...)
+ baseError := `unexpected query: INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`
+ if batchSize < oblast.PrepareThreshold {
+ assert.ErrEqual(t, err, "during QueryRow() for record with idx = 0: "+baseError)
+ } else {
+ assert.ErrEqual(t, err, "during Prepare(): "+baseError)
+ }
+ }
+}
+
+func TestUpdateFailsOnMissingRecord(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID int64 `db:"id,auto"`
+ Name string `db:"name"`
+ }
+ store := oblast.MustNewStore[basicRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ md.ForQuery(`UPDATE "basic_records" SET "name" = ? WHERE "id" = ?`).
+ ExpectExecWithArgs("changed", 42).
+ AndReturnRowsAffected(0)
+ err := store.Update(db, basicRecord{ID: 42, Name: "changed"})
+ assert.ErrEqual(t, err, "could not UPDATE record that does not exist in the database: id = 42")
+ _, hasCorrectType := err.(oblast.MissingRecordError[basicRecord]) //nolint:errorlint // we explicitly do not want a wrapped error
+ assert.Equal(t, hasCorrectType, true)
+}
+
+func TestInsertWithUnsignedIdField(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID uint64 `db:"id,auto"` // not int64!
+ Name string `db:"name"`
+ }
+
+ t.Run("using LastInsertID", func(t *testing.T) {
+ store := oblast.MustNewStore[basicRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ // success case
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`).
+ ExpectExecWithArgs("first").
+ AndReturnLastInsertId(42).
+ AndReturnRowsAffected(1)
+ records := must.Return(store.Insert(db, basicRecord{Name: "first"}))(t)
+ assert.SliceEqual(t, records, basicRecord{ID: 42, Name: "first"})
+
+ // error case: negative ID cannot be cast to uint64
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`).
+ ExpectExecWithArgs("second").
+ AndReturnLastInsertId(-42).
+ AndReturnRowsAffected(1)
+ _, err := store.Insert(db, basicRecord{Name: "second"})
+ assert.ErrEqual(t, err, "LastInsertId() = -42 for record with idx = 0 cannot be converted to uint")
+
+ // error case: cannot Insert() a record that already has its ID field filled
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`).
+ ExpectExecWithArgs("third").
+ AndReturnLastInsertId(42).
+ AndReturnRowsAffected(1)
+ _, err = store.Insert(db, basicRecord{ID: 23, Name: "third"})
+ assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0 that already has non-zero values in its "auto" columns`)
+ })
+
+ t.Run("using RETURNING clause", func(t *testing.T) {
+ store := oblast.MustNewStore[basicRecord](
+ oblast.PostgresDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ // success case
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`).
+ ExpectQueryWithArgs("first").
+ AndReturnColumns("id").
+ WithRow(42)
+ records := must.Return(store.Insert(db, basicRecord{Name: "first"}))(t)
+ assert.SliceEqual(t, records, basicRecord{ID: 42, Name: "first"})
+
+ // error case: negative ID cannot be cast to uint64
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`).
+ ExpectQueryWithArgs("second").
+ AndReturnColumns("id").
+ WithRow(-42)
+ _, err := store.Insert(db, basicRecord{Name: "second"})
+ assert.ErrEqual(t, err, `during QueryRow() for record with idx = 0: sql: Scan error on column index 0, name "id": converting driver.Value type int ("-42") to a uint64: invalid syntax`)
+
+ // error case: cannot Insert() a record that already has its ID field filled
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`).
+ ExpectQueryWithArgs("third").
+ AndReturnColumns("id").
+ WithRow(42)
+ _, err = store.Insert(db, basicRecord{ID: 23, Name: "third"})
+ assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0 that already has non-zero values in its "auto" columns`)
+ })
+}