From 2fe6a5a42ccb663211f4f4804b78fff3bd9ebdc0 Mon Sep 17 00:00:00 2001 From: Stefan Majewsky Date: Wed, 13 May 2026 00:39:22 +0200 Subject: Insert, Upsert, Update, Delete: do not panic on indirection through nil pointer --- CHANGELOG.md | 7 ++++++ oblast.go | 6 +++-- plan.go | 36 ++++++++++++++++++++++------- query.go | 67 ++++++++++++++++++++++++++++++++++++++++++------------ query_test.go | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ select.go | 8 +++---- 6 files changed, 169 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e9853f8..0f0aa71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,13 @@ SPDX-FileCopyrightText: 2026 Stefan Majewsky SPDX-License-Identifier: Apache-2.0 --> +# v0.8.0 (TBD) + +Changes: + +- Insert, Upsert, Update and Delete will no longer panic when one of the fields they need to access is within a pointer-to-struct that is nil. + Instead, an error will be returned in a controlled manner. + # v0.7.0 (2026-05-12) API changes: diff --git a/oblast.go b/oblast.go index 35b51e2..5adb33e 100644 --- a/oblast.go +++ b/oblast.go @@ -25,17 +25,19 @@ // Then use it many times to perform load and store operations: // // func doStuff(db *sql.DB) error { +// dbh := oblast.Wrap(db) +// // newEntry := LogEntry{ // CreatedAt: time.Now(), // Message: "Hello World.", // } -// err := logEntryStore.Insert(db, &newEntry) +// err := logEntryStore.Insert(dbh, &newEntry) // if err != nil { // return err // } // fmt.Printf("created log entry %d", newEntry.ID) // -// allEntries, err := logEntryStore.SelectWhere(db, `created_at < NOW()`) +// allEntries, err := logEntryStore.SelectWhere(dbh, `created_at < NOW()`) // if err != nil { // return err // } diff --git a/plan.go b/plan.go index 9dc38f6..830899e 100644 --- a/plan.go +++ b/plan.go @@ -23,8 +23,8 @@ type plan struct { // Field index (i.e. argument for reflect.Value.FieldByIndex()) for each column name. IndexByColumnName map[string][]int - // Indexes of pointer-typed fields that need to be initialized before scanning into this type. - IndexesOfTransparentPointerStructs [][]int + // Pointer-typed fields that need to be initialized before scanning into this type. + TransparentPointerStructFields []fieldInfo // Planned queries. Select plannedQuery // only `SELECT ... FROM ... WHERE `; user supplies the rest during Select{,One}Where() @@ -34,6 +34,13 @@ type plan struct { Delete plannedQuery } +// fieldInfo appears in type plan. +type fieldInfo struct { + Name string + Index []int + ContainsPrimaryKey bool +} + // plannedQuery appears in type plan. type plannedQuery struct { // Empty if the respective query type is not supported by this plan for lack of the required marker types. @@ -80,11 +87,6 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) { // discover addressable fields in this type, collect information from markers and tags for _, field := range reflect.VisibleFields(t) { - // ignore unexported fields (otherwise reflect.Value.Interface() on the field would panic) - if field.PkgPath != "" { - continue - } - // recurse into struct fields (i.e. ignore the struct itself and consider its members instead) // unless the field itself has a `db:"..."` tag if field.Type.Kind() == reflect.Struct || (field.Type.Kind() == reflect.Pointer && field.Type.Elem().Kind() == reflect.Struct) { @@ -93,13 +95,22 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) { if field.Type.Kind() == reflect.Pointer { // remember that, when scanning into a record of type `t`, we need to write a non-nil zeroed struct into this field // to enable taking an address of its mapped member fields - p.IndexesOfTransparentPointerStructs = append(p.IndexesOfTransparentPointerStructs, field.Index) + p.TransparentPointerStructFields = append(p.TransparentPointerStructFields, fieldInfo{ + Name: field.Name, + Index: field.Index, + ContainsPrimaryKey: false, // might be set later + }) } continue } indexesOfOpaqueStructs = append(indexesOfOpaqueStructs, field.Index) } + // ignore unexported fields (otherwise reflect.Value.Interface() on the field would panic) + if field.PkgPath != "" { + continue + } + // ignore fields that are within a struct type that is mapped as a whole if slices.ContainsFunc(indexesOfOpaqueStructs, func(index []int) bool { return isWithin(field.Index, index) @@ -132,6 +143,15 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) { } } + // track which transparent pointer structs contain PK fields + if slices.Contains(p.PrimaryKeyColumnNames, columnName) { + for idx, tpsField := range p.TransparentPointerStructFields { + if isWithin(field.Index, tpsField.Index) { + p.TransparentPointerStructFields[idx].ContainsPrimaryKey = true + } + } + } + for _, tag := range extraTags { switch tag { case "auto": diff --git a/query.go b/query.go index de151a6..eea1771 100644 --- a/query.go +++ b/query.go @@ -68,7 +68,11 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han for idx, r := range records { v := reflect.ValueOf(r).Elem() - err := insertRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots) + err := checkTransparentPointerStructFieldsInitialized("INSERT", idx, v, s.plan, false) + if err != nil { + return newIOError(err, "Stmt.Close", stmt.Close()) + } + err = insertRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots) if err != nil { return newIOError(err, "Stmt.Close", stmt.Close()) } @@ -78,7 +82,6 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han } func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error { - // TODO: check plan.IndexesOfTransparentPointerStructs, return error (instead of panicking on FieldByIndex) if pointer struct is nil for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } @@ -102,6 +105,27 @@ func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt ha return nil } +// This check must be performed within all query functions that access existing values using FieldByIndex(), +// to ensure that FieldByIndex() does not panic on indirection through a nil pointer. +func checkTransparentPointerStructFieldsInitialized(operation string, recordIndex int, v reflect.Value, plan plan, onlyPK bool) error { + for _, field := range plan.TransparentPointerStructFields { + f := v.FieldByIndex(field.Index) + if !f.IsZero() { + continue + } + if onlyPK { + if field.ContainsPrimaryKey { + return fmt.Errorf(`refusing to %s record with idx = %d: cannot access all primary key fields because field %q holds a nil pointer`, + operation, recordIndex, field.Name) + } + } else { + return fmt.Errorf(`refusing to %s record with idx = %d: cannot access all mapped fields because field %q holds a nil pointer`, + operation, recordIndex, field.Name) + } + } + return nil +} + // Update executes an SQL UPDATE statement for each of the provided records, updating all non-primary-key columns with the values in the records. // Returns [MissingRecordError] if any of the records does not exist in the database, that is, if for any of the records, the database contains no row with the same primary key values. // @@ -122,6 +146,10 @@ func (s Store[R]) Update(ctx context.Context, db Handle, records ...R) error { for idx := range records { v := reflect.ValueOf(&records[idx]).Elem() + err := checkTransparentPointerStructFieldsInitialized("UPDATE", idx, v, s.plan, false) + if err != nil { + return newIOError(err, "Stmt.Close", stmt.Close()) + } rowsAffected, err := updateRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots) if err == nil && rowsAffected == 0 { err = MissingRecordError[R]{records[idx], s.plan} @@ -134,7 +162,6 @@ func (s Store[R]) Update(ctx context.Context, db Handle, records ...R) error { } func updateRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) (int64, error) { - // TODO: check plan.IndexesOfTransparentPointerStructs, return error (instead of panicking on FieldByIndex) if pointer struct is nil for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } @@ -168,7 +195,7 @@ func (s Store[R]) Delete(ctx context.Context, db Handle, records ...R) error { for idx := range records { v := reflect.ValueOf(&records[idx]).Elem() - err := deleteRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots) + err := deleteRecord(ctx, s.plan, v, idx, stmt, argumentIndexes, argumentSlots) if err != nil { return newIOError(err, "Stmt.Close", stmt.Close()) } @@ -177,13 +204,15 @@ func (s Store[R]) Delete(ctx context.Context, db Handle, records ...R) error { return newIOError(nil, "Stmt.Close", stmt.Close()) } -func deleteRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) error { - // TODO: consider checking plan.IndexesOfTransparentPointerStructs and returning an error (instead of panicking on FieldByIndex) if pointer struct is nil - // (might want to have bookkeeping to only check pointer structs that lead to PK fields) +func deleteRecord(ctx context.Context, plan plan, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) error { + err := checkTransparentPointerStructFieldsInitialized("DELETE", recordIndex, v, plan, true) + if err != nil { + return newIOError(err, "Stmt.Close", stmt.Close()) + } for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } - _, err := stmt.Exec(ctx, argumentSlots) + _, err = stmt.Exec(ctx, argumentSlots) if err != nil { return fmt.Errorf("while deleting record with idx = %d: %w", recordIndex, err) } @@ -217,9 +246,19 @@ func (s Store[R]) Upsert(ctx context.Context, db Handle, records ...*R) error { } updateStmt, err := prepare(ctx, db, s.plan.Update.Query, "Update", 0) if err != nil { - return err + return newIOError(err, "InsertStmt.Close", insertStmt.Close()) } + err = s.doUpsert(ctx, db, insertStmt, updateStmt, records) + err = newIOError(err, "InsertStmt.Close", insertStmt.Close()) + err = newIOError(err, "UpdateStmt.Close", updateStmt.Close()) + return err +} + +func (s Store[R]) doUpsert(ctx context.Context, db Handle, insertStmt, updateStmt handle.Statement, records []*R) error { + // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. + // Any expression that does not depend on type R should be factored out into a reusable function. + var ( insertArgumentIndexes = s.plan.Insert.ArgumentIndexes insertArgumentSlots = make([]any, len(insertArgumentIndexes)) @@ -231,6 +270,10 @@ func (s Store[R]) Upsert(ctx context.Context, db Handle, records ...*R) error { for idx, r := range records { v := reflect.ValueOf(r).Elem() + err := checkTransparentPointerStructFieldsInitialized("INSERT or UPDATE", idx, v, s.plan, false) + if err != nil { + return err + } isInsert, err := upsertDecideStrategy(v, idx, insertScanIndexes) if err != nil { return err @@ -246,15 +289,11 @@ func (s Store[R]) Upsert(ctx context.Context, db Handle, records ...*R) error { } } if err != nil { - err = newIOError(err, "InsertStmt.Close", insertStmt.Close()) - err = newIOError(err, "UpdateStmt.Close", updateStmt.Close()) return err } } - err = newIOError(err, "InsertStmt.Close", insertStmt.Close()) - err = newIOError(err, "UpdateStmt.Close", updateStmt.Close()) - return err + return nil } func upsertDecideStrategy(v reflect.Value, recordIndex int, scanIndexes [][]int) (isInsert bool, err error) { diff --git a/query_test.go b/query_test.go index 41b0203..382a463 100644 --- a/query_test.go +++ b/query_test.go @@ -346,3 +346,76 @@ func TestUpsertFailsOnMixedAutoFieldState(t *testing.T) { err := store.Upsert(ctx, db, &brokenRecord) assert.ErrEqual(t, err, `cannot decide whether to INSERT or UPDATE record with idx = 0: some "auto" columns are zero, others are not`) } + +func TestUninitializedTransparentPointerStructs(t *testing.T) { + ctx := t.Context() + md := mock.NewDriver() + db := oblast.Wrap(sql.OpenDB(md)) + + // declare a record type that has a transparent pointer struct containing non-primary-key fields + type timestamps struct { + CreatedAt time.Time `db:"created_at"` + DeletedAt *time.Time `db:"deleted_at"` + } + type nestedRecord struct { + ID int64 `db:"id,auto"` + Name string `db:"name"` + *timestamps + } + nestedRecordStore := oblast.MustNewStore[nestedRecord]( + oblast.SqliteDialect(), + oblast.TableNameIs("nested_records"), + oblast.PrimaryKeyIs("id"), + ) + + // declare another record type that has a primary key field within a transparent pointer struct + type commonFields struct { + ID int64 `db:"id,auto"` + CreatedAt time.Time `db:"created_at"` + DeletedAt *time.Time `db:"deleted_at"` + } + type weirdRecord struct { + *commonFields + Name string `db:"name"` + } + weirdRecordStore := oblast.MustNewStore[weirdRecord]( + oblast.SqliteDialect(), + oblast.TableNameIs("weird_records"), + oblast.PrimaryKeyIs("id"), + ) + + // check detection on INSERT + freshBrokenRecord := nestedRecord{ + Name: "foo", + timestamps: nil, // problem: cannot access `freshBrokenRecord.CreatedAt` or `freshBrokenRecord.DeletedAt` + } + err := nestedRecordStore.Insert(ctx, db, &freshBrokenRecord) + assert.ErrEqual(t, err, `refusing to INSERT record with idx = 0: cannot access all mapped fields because field "timestamps" holds a nil pointer`) + err = nestedRecordStore.Upsert(ctx, db, &freshBrokenRecord) + assert.ErrEqual(t, err, `refusing to INSERT or UPDATE record with idx = 0: cannot access all mapped fields because field "timestamps" holds a nil pointer`) + + // check detection on UPDATE + existingBrokenRecord := nestedRecord{ + ID: 42, + Name: "bar", + timestamps: nil, // same problem as above + } + err = nestedRecordStore.Update(ctx, db, existingBrokenRecord) + assert.ErrEqual(t, err, `refusing to UPDATE record with idx = 0: cannot access all mapped fields because field "timestamps" holds a nil pointer`) + err = nestedRecordStore.Upsert(ctx, db, &freshBrokenRecord) + assert.ErrEqual(t, err, `refusing to INSERT or UPDATE record with idx = 0: cannot access all mapped fields because field "timestamps" holds a nil pointer`) + + // check that detection on DELETE does not care about transparent pointer structs as long as they do not contain PK fields + md.ForQuery(`DELETE FROM "nested_records" WHERE "id" = ?`). + ExpectExecWithArgs(42). + AndReturnRowsAffected(1) + must.Succeed(t, nestedRecordStore.Delete(ctx, db, existingBrokenRecord)) + + // check detection on DELETE where it matters + existingWeirdRecord := weirdRecord{ + commonFields: nil, // problem: cannot access `existingWeirdRecord.ID` + Name: "qux", + } + err = weirdRecordStore.Delete(ctx, db, existingWeirdRecord) + assert.ErrEqual(t, err, `refusing to DELETE record with idx = 0: cannot access all primary key fields because field "commonFields" holds a nil pointer`) +} diff --git a/select.go b/select.go index 87dd36b..17195d0 100644 --- a/select.go +++ b/select.go @@ -144,8 +144,8 @@ func startSelectWhereQuery(ctx context.Context, db Handle, plan plan, partialQue } func collectRow(rows handle.Rows, plan plan, v reflect.Value, slots []any, indexes [][]int) error { - for _, index := range plan.IndexesOfTransparentPointerStructs { - f := v.FieldByIndex(index) + for _, field := range plan.TransparentPointerStructFields { + f := v.FieldByIndex(field.Index) f.Set(reflect.New(f.Type().Elem())) } for idx, index := range indexes { @@ -232,8 +232,8 @@ func selectOneWhere(ctx context.Context, db Handle, plan plan, v reflect.Value, } func selectOne(ctx context.Context, db Handle, plan plan, v reflect.Value, query string, args []any) error { - for _, index := range plan.IndexesOfTransparentPointerStructs { - f := v.FieldByIndex(index) + for _, field := range plan.TransparentPointerStructFields { + f := v.FieldByIndex(field.Index) f.Set(reflect.New(f.Type().Elem())) } slots := make([]any, len(plan.Select.ScanIndexes)) -- cgit v1.2.3