aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Majewsky <majewsky@gmx.net>2026-05-13 00:39:22 +0200
committerStefan Majewsky <majewsky@gmx.net>2026-05-13 00:40:24 +0200
commit2fe6a5a42ccb663211f4f4804b78fff3bd9ebdc0 (patch)
tree08fe0bcc17dcd617d08a14847375710b80d86d8d
parenta86a346ecceb7ad409f116474c1593b201012cf2 (diff)
downloadgo-oblast-2fe6a5a42ccb663211f4f4804b78fff3bd9ebdc0.tar.gz
Insert, Upsert, Update, Delete: do not panic on indirection through nil pointer
-rw-r--r--CHANGELOG.md7
-rw-r--r--oblast.go6
-rw-r--r--plan.go36
-rw-r--r--query.go67
-rw-r--r--query_test.go73
-rw-r--r--select.go8
6 files changed, 169 insertions, 28 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e9853f8..0f0aa71 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,13 @@ SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net>
SPDX-License-Identifier: Apache-2.0
-->
+# v0.8.0 (TBD)
+
+Changes:
+
+- Insert, Upsert, Update and Delete will no longer panic when one of the fields they need to access is within a pointer-to-struct that is nil.
+ Instead, an error will be returned in a controlled manner.
+
# v0.7.0 (2026-05-12)
API changes:
diff --git a/oblast.go b/oblast.go
index 35b51e2..5adb33e 100644
--- a/oblast.go
+++ b/oblast.go
@@ -25,17 +25,19 @@
// Then use it many times to perform load and store operations:
//
// func doStuff(db *sql.DB) error {
+// dbh := oblast.Wrap(db)
+//
// newEntry := LogEntry{
// CreatedAt: time.Now(),
// Message: "Hello World.",
// }
-// err := logEntryStore.Insert(db, &newEntry)
+// err := logEntryStore.Insert(dbh, &newEntry)
// if err != nil {
// return err
// }
// fmt.Printf("created log entry %d", newEntry.ID)
//
-// allEntries, err := logEntryStore.SelectWhere(db, `created_at < NOW()`)
+// allEntries, err := logEntryStore.SelectWhere(dbh, `created_at < NOW()`)
// if err != nil {
// return err
// }
diff --git a/plan.go b/plan.go
index 9dc38f6..830899e 100644
--- a/plan.go
+++ b/plan.go
@@ -23,8 +23,8 @@ type plan struct {
// Field index (i.e. argument for reflect.Value.FieldByIndex()) for each column name.
IndexByColumnName map[string][]int
- // Indexes of pointer-typed fields that need to be initialized before scanning into this type.
- IndexesOfTransparentPointerStructs [][]int
+ // Pointer-typed fields that need to be initialized before scanning into this type.
+ TransparentPointerStructFields []fieldInfo
// Planned queries.
Select plannedQuery // only `SELECT ... FROM ... WHERE `; user supplies the rest during Select{,One}Where()
@@ -34,6 +34,13 @@ type plan struct {
Delete plannedQuery
}
+// fieldInfo appears in type plan.
+type fieldInfo struct {
+ Name string
+ Index []int
+ ContainsPrimaryKey bool
+}
+
// plannedQuery appears in type plan.
type plannedQuery struct {
// Empty if the respective query type is not supported by this plan for lack of the required marker types.
@@ -80,11 +87,6 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) {
// discover addressable fields in this type, collect information from markers and tags
for _, field := range reflect.VisibleFields(t) {
- // ignore unexported fields (otherwise reflect.Value.Interface() on the field would panic)
- if field.PkgPath != "" {
- continue
- }
-
// recurse into struct fields (i.e. ignore the struct itself and consider its members instead)
// unless the field itself has a `db:"..."` tag
if field.Type.Kind() == reflect.Struct || (field.Type.Kind() == reflect.Pointer && field.Type.Elem().Kind() == reflect.Struct) {
@@ -93,13 +95,22 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) {
if field.Type.Kind() == reflect.Pointer {
// remember that, when scanning into a record of type `t`, we need to write a non-nil zeroed struct into this field
// to enable taking an address of its mapped member fields
- p.IndexesOfTransparentPointerStructs = append(p.IndexesOfTransparentPointerStructs, field.Index)
+ p.TransparentPointerStructFields = append(p.TransparentPointerStructFields, fieldInfo{
+ Name: field.Name,
+ Index: field.Index,
+ ContainsPrimaryKey: false, // might be set later
+ })
}
continue
}
indexesOfOpaqueStructs = append(indexesOfOpaqueStructs, field.Index)
}
+ // ignore unexported fields (otherwise reflect.Value.Interface() on the field would panic)
+ if field.PkgPath != "" {
+ continue
+ }
+
// ignore fields that are within a struct type that is mapped as a whole
if slices.ContainsFunc(indexesOfOpaqueStructs, func(index []int) bool {
return isWithin(field.Index, index)
@@ -132,6 +143,15 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) {
}
}
+ // track which transparent pointer structs contain PK fields
+ if slices.Contains(p.PrimaryKeyColumnNames, columnName) {
+ for idx, tpsField := range p.TransparentPointerStructFields {
+ if isWithin(field.Index, tpsField.Index) {
+ p.TransparentPointerStructFields[idx].ContainsPrimaryKey = true
+ }
+ }
+ }
+
for _, tag := range extraTags {
switch tag {
case "auto":
diff --git a/query.go b/query.go
index de151a6..eea1771 100644
--- a/query.go
+++ b/query.go
@@ -68,7 +68,11 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han
for idx, r := range records {
v := reflect.ValueOf(r).Elem()
- err := insertRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots)
+ err := checkTransparentPointerStructFieldsInitialized("INSERT", idx, v, s.plan, false)
+ if err != nil {
+ return newIOError(err, "Stmt.Close", stmt.Close())
+ }
+ err = insertRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots)
if err != nil {
return newIOError(err, "Stmt.Close", stmt.Close())
}
@@ -78,7 +82,6 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han
}
func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error {
- // TODO: check plan.IndexesOfTransparentPointerStructs, return error (instead of panicking on FieldByIndex) if pointer struct is nil
for idx, index := range argumentIndexes {
argumentSlots[idx] = v.FieldByIndex(index).Interface()
}
@@ -102,6 +105,27 @@ func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt ha
return nil
}
+// This check must be performed within all query functions that access existing values using FieldByIndex(),
+// to ensure that FieldByIndex() does not panic on indirection through a nil pointer.
+func checkTransparentPointerStructFieldsInitialized(operation string, recordIndex int, v reflect.Value, plan plan, onlyPK bool) error {
+ for _, field := range plan.TransparentPointerStructFields {
+ f := v.FieldByIndex(field.Index)
+ if !f.IsZero() {
+ continue
+ }
+ if onlyPK {
+ if field.ContainsPrimaryKey {
+ return fmt.Errorf(`refusing to %s record with idx = %d: cannot access all primary key fields because field %q holds a nil pointer`,
+ operation, recordIndex, field.Name)
+ }
+ } else {
+ return fmt.Errorf(`refusing to %s record with idx = %d: cannot access all mapped fields because field %q holds a nil pointer`,
+ operation, recordIndex, field.Name)
+ }
+ }
+ return nil
+}
+
// Update executes an SQL UPDATE statement for each of the provided records, updating all non-primary-key columns with the values in the records.
// Returns [MissingRecordError] if any of the records does not exist in the database, that is, if for any of the records, the database contains no row with the same primary key values.
//
@@ -122,6 +146,10 @@ func (s Store[R]) Update(ctx context.Context, db Handle, records ...R) error {
for idx := range records {
v := reflect.ValueOf(&records[idx]).Elem()
+ err := checkTransparentPointerStructFieldsInitialized("UPDATE", idx, v, s.plan, false)
+ if err != nil {
+ return newIOError(err, "Stmt.Close", stmt.Close())
+ }
rowsAffected, err := updateRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots)
if err == nil && rowsAffected == 0 {
err = MissingRecordError[R]{records[idx], s.plan}
@@ -134,7 +162,6 @@ func (s Store[R]) Update(ctx context.Context, db Handle, records ...R) error {
}
func updateRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) (int64, error) {
- // TODO: check plan.IndexesOfTransparentPointerStructs, return error (instead of panicking on FieldByIndex) if pointer struct is nil
for idx, index := range argumentIndexes {
argumentSlots[idx] = v.FieldByIndex(index).Interface()
}
@@ -168,7 +195,7 @@ func (s Store[R]) Delete(ctx context.Context, db Handle, records ...R) error {
for idx := range records {
v := reflect.ValueOf(&records[idx]).Elem()
- err := deleteRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots)
+ err := deleteRecord(ctx, s.plan, v, idx, stmt, argumentIndexes, argumentSlots)
if err != nil {
return newIOError(err, "Stmt.Close", stmt.Close())
}
@@ -177,13 +204,15 @@ func (s Store[R]) Delete(ctx context.Context, db Handle, records ...R) error {
return newIOError(nil, "Stmt.Close", stmt.Close())
}
-func deleteRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) error {
- // TODO: consider checking plan.IndexesOfTransparentPointerStructs and returning an error (instead of panicking on FieldByIndex) if pointer struct is nil
- // (might want to have bookkeeping to only check pointer structs that lead to PK fields)
+func deleteRecord(ctx context.Context, plan plan, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) error {
+ err := checkTransparentPointerStructFieldsInitialized("DELETE", recordIndex, v, plan, true)
+ if err != nil {
+ return newIOError(err, "Stmt.Close", stmt.Close())
+ }
for idx, index := range argumentIndexes {
argumentSlots[idx] = v.FieldByIndex(index).Interface()
}
- _, err := stmt.Exec(ctx, argumentSlots)
+ _, err = stmt.Exec(ctx, argumentSlots)
if err != nil {
return fmt.Errorf("while deleting record with idx = %d: %w", recordIndex, err)
}
@@ -217,9 +246,19 @@ func (s Store[R]) Upsert(ctx context.Context, db Handle, records ...*R) error {
}
updateStmt, err := prepare(ctx, db, s.plan.Update.Query, "Update", 0)
if err != nil {
- return err
+ return newIOError(err, "InsertStmt.Close", insertStmt.Close())
}
+ err = s.doUpsert(ctx, db, insertStmt, updateStmt, records)
+ err = newIOError(err, "InsertStmt.Close", insertStmt.Close())
+ err = newIOError(err, "UpdateStmt.Close", updateStmt.Close())
+ return err
+}
+
+func (s Store[R]) doUpsert(ctx context.Context, db Handle, insertStmt, updateStmt handle.Statement, records []*R) error {
+ // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization.
+ // Any expression that does not depend on type R should be factored out into a reusable function.
+
var (
insertArgumentIndexes = s.plan.Insert.ArgumentIndexes
insertArgumentSlots = make([]any, len(insertArgumentIndexes))
@@ -231,6 +270,10 @@ func (s Store[R]) Upsert(ctx context.Context, db Handle, records ...*R) error {
for idx, r := range records {
v := reflect.ValueOf(r).Elem()
+ err := checkTransparentPointerStructFieldsInitialized("INSERT or UPDATE", idx, v, s.plan, false)
+ if err != nil {
+ return err
+ }
isInsert, err := upsertDecideStrategy(v, idx, insertScanIndexes)
if err != nil {
return err
@@ -246,15 +289,11 @@ func (s Store[R]) Upsert(ctx context.Context, db Handle, records ...*R) error {
}
}
if err != nil {
- err = newIOError(err, "InsertStmt.Close", insertStmt.Close())
- err = newIOError(err, "UpdateStmt.Close", updateStmt.Close())
return err
}
}
- err = newIOError(err, "InsertStmt.Close", insertStmt.Close())
- err = newIOError(err, "UpdateStmt.Close", updateStmt.Close())
- return err
+ return nil
}
func upsertDecideStrategy(v reflect.Value, recordIndex int, scanIndexes [][]int) (isInsert bool, err error) {
diff --git a/query_test.go b/query_test.go
index 41b0203..382a463 100644
--- a/query_test.go
+++ b/query_test.go
@@ -346,3 +346,76 @@ func TestUpsertFailsOnMixedAutoFieldState(t *testing.T) {
err := store.Upsert(ctx, db, &brokenRecord)
assert.ErrEqual(t, err, `cannot decide whether to INSERT or UPDATE record with idx = 0: some "auto" columns are zero, others are not`)
}
+
+func TestUninitializedTransparentPointerStructs(t *testing.T) {
+ ctx := t.Context()
+ md := mock.NewDriver()
+ db := oblast.Wrap(sql.OpenDB(md))
+
+ // declare a record type that has a transparent pointer struct containing non-primary-key fields
+ type timestamps struct {
+ CreatedAt time.Time `db:"created_at"`
+ DeletedAt *time.Time `db:"deleted_at"`
+ }
+ type nestedRecord struct {
+ ID int64 `db:"id,auto"`
+ Name string `db:"name"`
+ *timestamps
+ }
+ nestedRecordStore := oblast.MustNewStore[nestedRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("nested_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ // declare another record type that has a primary key field within a transparent pointer struct
+ type commonFields struct {
+ ID int64 `db:"id,auto"`
+ CreatedAt time.Time `db:"created_at"`
+ DeletedAt *time.Time `db:"deleted_at"`
+ }
+ type weirdRecord struct {
+ *commonFields
+ Name string `db:"name"`
+ }
+ weirdRecordStore := oblast.MustNewStore[weirdRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("weird_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ // check detection on INSERT
+ freshBrokenRecord := nestedRecord{
+ Name: "foo",
+ timestamps: nil, // problem: cannot access `freshBrokenRecord.CreatedAt` or `freshBrokenRecord.DeletedAt`
+ }
+ err := nestedRecordStore.Insert(ctx, db, &freshBrokenRecord)
+ assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0: cannot access all mapped fields because field "timestamps" holds a nil pointer`)
+ err = nestedRecordStore.Upsert(ctx, db, &freshBrokenRecord)
+ assert.ErrEqual(t, err, `refusing to INSERT or UPDATE record with idx = 0: cannot access all mapped fields because field "timestamps" holds a nil pointer`)
+
+ // check detection on UPDATE
+ existingBrokenRecord := nestedRecord{
+ ID: 42,
+ Name: "bar",
+ timestamps: nil, // same problem as above
+ }
+ err = nestedRecordStore.Update(ctx, db, existingBrokenRecord)
+ assert.ErrEqual(t, err, `refusing to UPDATE record with idx = 0: cannot access all mapped fields because field "timestamps" holds a nil pointer`)
+ err = nestedRecordStore.Upsert(ctx, db, &freshBrokenRecord)
+ assert.ErrEqual(t, err, `refusing to INSERT or UPDATE record with idx = 0: cannot access all mapped fields because field "timestamps" holds a nil pointer`)
+
+ // check that detection on DELETE does not care about transparent pointer structs as long as they do not contain PK fields
+ md.ForQuery(`DELETE FROM "nested_records" WHERE "id" = ?`).
+ ExpectExecWithArgs(42).
+ AndReturnRowsAffected(1)
+ must.Succeed(t, nestedRecordStore.Delete(ctx, db, existingBrokenRecord))
+
+ // check detection on DELETE where it matters
+ existingWeirdRecord := weirdRecord{
+ commonFields: nil, // problem: cannot access `existingWeirdRecord.ID`
+ Name: "qux",
+ }
+ err = weirdRecordStore.Delete(ctx, db, existingWeirdRecord)
+ assert.ErrEqual(t, err, `refusing to DELETE record with idx = 0: cannot access all primary key fields because field "commonFields" holds a nil pointer`)
+}
diff --git a/select.go b/select.go
index 87dd36b..17195d0 100644
--- a/select.go
+++ b/select.go
@@ -144,8 +144,8 @@ func startSelectWhereQuery(ctx context.Context, db Handle, plan plan, partialQue
}
func collectRow(rows handle.Rows, plan plan, v reflect.Value, slots []any, indexes [][]int) error {
- for _, index := range plan.IndexesOfTransparentPointerStructs {
- f := v.FieldByIndex(index)
+ for _, field := range plan.TransparentPointerStructFields {
+ f := v.FieldByIndex(field.Index)
f.Set(reflect.New(f.Type().Elem()))
}
for idx, index := range indexes {
@@ -232,8 +232,8 @@ func selectOneWhere(ctx context.Context, db Handle, plan plan, v reflect.Value,
}
func selectOne(ctx context.Context, db Handle, plan plan, v reflect.Value, query string, args []any) error {
- for _, index := range plan.IndexesOfTransparentPointerStructs {
- f := v.FieldByIndex(index)
+ for _, field := range plan.TransparentPointerStructFields {
+ f := v.FieldByIndex(field.Index)
f.Set(reflect.New(f.Type().Elem()))
}
slots := make([]any, len(plan.Select.ScanIndexes))