diff options
| author | Stefan Majewsky <majewsky@gmx.net> | 2026-05-13 00:39:22 +0200 |
|---|---|---|
| committer | Stefan Majewsky <majewsky@gmx.net> | 2026-05-13 00:40:24 +0200 |
| commit | 2fe6a5a42ccb663211f4f4804b78fff3bd9ebdc0 (patch) | |
| tree | 08fe0bcc17dcd617d08a14847375710b80d86d8d /query.go | |
| parent | a86a346ecceb7ad409f116474c1593b201012cf2 (diff) | |
| download | go-oblast-2fe6a5a42ccb663211f4f4804b78fff3bd9ebdc0.tar.gz | |
Insert, Upsert, Update, Delete: do not panic on indirection through nil pointer
Diffstat (limited to 'query.go')
| -rw-r--r-- | query.go | 67 |
1 files changed, 53 insertions, 14 deletions
@@ -68,7 +68,11 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han for idx, r := range records { v := reflect.ValueOf(r).Elem() - err := insertRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots) + err := checkTransparentPointerStructFieldsInitialized("INSERT", idx, v, s.plan, false) + if err != nil { + return newIOError(err, "Stmt.Close", stmt.Close()) + } + err = insertRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots) if err != nil { return newIOError(err, "Stmt.Close", stmt.Close()) } @@ -78,7 +82,6 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han } func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error { - // TODO: check plan.IndexesOfTransparentPointerStructs, return error (instead of panicking on FieldByIndex) if pointer struct is nil for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } @@ -102,6 +105,27 @@ func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt ha return nil } +// This check must be performed within all query functions that access existing values using FieldByIndex(), +// to ensure that FieldByIndex() does not panic on indirection through a nil pointer. +func checkTransparentPointerStructFieldsInitialized(operation string, recordIndex int, v reflect.Value, plan plan, onlyPK bool) error { + for _, field := range plan.TransparentPointerStructFields { + f := v.FieldByIndex(field.Index) + if !f.IsZero() { + continue + } + if onlyPK { + if field.ContainsPrimaryKey { + return fmt.Errorf(`refusing to %s record with idx = %d: cannot access all primary key fields because field %q holds a nil pointer`, + operation, recordIndex, field.Name) + } + } else { + return fmt.Errorf(`refusing to %s record with idx = %d: cannot access all mapped fields because field %q holds a nil pointer`, + operation, recordIndex, field.Name) + } + } + return nil +} + // Update executes an SQL UPDATE statement for each of the provided records, updating all non-primary-key columns with the values in the records. // Returns [MissingRecordError] if any of the records does not exist in the database, that is, if for any of the records, the database contains no row with the same primary key values. // @@ -122,6 +146,10 @@ func (s Store[R]) Update(ctx context.Context, db Handle, records ...R) error { for idx := range records { v := reflect.ValueOf(&records[idx]).Elem() + err := checkTransparentPointerStructFieldsInitialized("UPDATE", idx, v, s.plan, false) + if err != nil { + return newIOError(err, "Stmt.Close", stmt.Close()) + } rowsAffected, err := updateRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots) if err == nil && rowsAffected == 0 { err = MissingRecordError[R]{records[idx], s.plan} @@ -134,7 +162,6 @@ func (s Store[R]) Update(ctx context.Context, db Handle, records ...R) error { } func updateRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) (int64, error) { - // TODO: check plan.IndexesOfTransparentPointerStructs, return error (instead of panicking on FieldByIndex) if pointer struct is nil for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } @@ -168,7 +195,7 @@ func (s Store[R]) Delete(ctx context.Context, db Handle, records ...R) error { for idx := range records { v := reflect.ValueOf(&records[idx]).Elem() - err := deleteRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots) + err := deleteRecord(ctx, s.plan, v, idx, stmt, argumentIndexes, argumentSlots) if err != nil { return newIOError(err, "Stmt.Close", stmt.Close()) } @@ -177,13 +204,15 @@ func (s Store[R]) Delete(ctx context.Context, db Handle, records ...R) error { return newIOError(nil, "Stmt.Close", stmt.Close()) } -func deleteRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) error { - // TODO: consider checking plan.IndexesOfTransparentPointerStructs and returning an error (instead of panicking on FieldByIndex) if pointer struct is nil - // (might want to have bookkeeping to only check pointer structs that lead to PK fields) +func deleteRecord(ctx context.Context, plan plan, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) error { + err := checkTransparentPointerStructFieldsInitialized("DELETE", recordIndex, v, plan, true) + if err != nil { + return newIOError(err, "Stmt.Close", stmt.Close()) + } for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } - _, err := stmt.Exec(ctx, argumentSlots) + _, err = stmt.Exec(ctx, argumentSlots) if err != nil { return fmt.Errorf("while deleting record with idx = %d: %w", recordIndex, err) } @@ -217,9 +246,19 @@ func (s Store[R]) Upsert(ctx context.Context, db Handle, records ...*R) error { } updateStmt, err := prepare(ctx, db, s.plan.Update.Query, "Update", 0) if err != nil { - return err + return newIOError(err, "InsertStmt.Close", insertStmt.Close()) } + err = s.doUpsert(ctx, db, insertStmt, updateStmt, records) + err = newIOError(err, "InsertStmt.Close", insertStmt.Close()) + err = newIOError(err, "UpdateStmt.Close", updateStmt.Close()) + return err +} + +func (s Store[R]) doUpsert(ctx context.Context, db Handle, insertStmt, updateStmt handle.Statement, records []*R) error { + // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. + // Any expression that does not depend on type R should be factored out into a reusable function. + var ( insertArgumentIndexes = s.plan.Insert.ArgumentIndexes insertArgumentSlots = make([]any, len(insertArgumentIndexes)) @@ -231,6 +270,10 @@ func (s Store[R]) Upsert(ctx context.Context, db Handle, records ...*R) error { for idx, r := range records { v := reflect.ValueOf(r).Elem() + err := checkTransparentPointerStructFieldsInitialized("INSERT or UPDATE", idx, v, s.plan, false) + if err != nil { + return err + } isInsert, err := upsertDecideStrategy(v, idx, insertScanIndexes) if err != nil { return err @@ -246,15 +289,11 @@ func (s Store[R]) Upsert(ctx context.Context, db Handle, records ...*R) error { } } if err != nil { - err = newIOError(err, "InsertStmt.Close", insertStmt.Close()) - err = newIOError(err, "UpdateStmt.Close", updateStmt.Close()) return err } } - err = newIOError(err, "InsertStmt.Close", insertStmt.Close()) - err = newIOError(err, "UpdateStmt.Close", updateStmt.Close()) - return err + return nil } func upsertDecideStrategy(v reflect.Value, recordIndex int, scanIndexes [][]int) (isInsert bool, err error) { |
