aboutsummaryrefslogtreecommitdiff
path: root/query.go
diff options
context:
space:
mode:
Diffstat (limited to 'query.go')
-rw-r--r--query.go67
1 files changed, 53 insertions, 14 deletions
diff --git a/query.go b/query.go
index de151a6..eea1771 100644
--- a/query.go
+++ b/query.go
@@ -68,7 +68,11 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han
for idx, r := range records {
v := reflect.ValueOf(r).Elem()
- err := insertRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots)
+ err := checkTransparentPointerStructFieldsInitialized("INSERT", idx, v, s.plan, false)
+ if err != nil {
+ return newIOError(err, "Stmt.Close", stmt.Close())
+ }
+ err = insertRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots)
if err != nil {
return newIOError(err, "Stmt.Close", stmt.Close())
}
@@ -78,7 +82,6 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han
}
func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error {
- // TODO: check plan.IndexesOfTransparentPointerStructs, return error (instead of panicking on FieldByIndex) if pointer struct is nil
for idx, index := range argumentIndexes {
argumentSlots[idx] = v.FieldByIndex(index).Interface()
}
@@ -102,6 +105,27 @@ func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt ha
return nil
}
+// This check must be performed within all query functions that access existing values using FieldByIndex(),
+// to ensure that FieldByIndex() does not panic on indirection through a nil pointer.
+func checkTransparentPointerStructFieldsInitialized(operation string, recordIndex int, v reflect.Value, plan plan, onlyPK bool) error {
+ for _, field := range plan.TransparentPointerStructFields {
+ f := v.FieldByIndex(field.Index)
+ if !f.IsZero() {
+ continue
+ }
+ if onlyPK {
+ if field.ContainsPrimaryKey {
+ return fmt.Errorf(`refusing to %s record with idx = %d: cannot access all primary key fields because field %q holds a nil pointer`,
+ operation, recordIndex, field.Name)
+ }
+ } else {
+ return fmt.Errorf(`refusing to %s record with idx = %d: cannot access all mapped fields because field %q holds a nil pointer`,
+ operation, recordIndex, field.Name)
+ }
+ }
+ return nil
+}
+
// Update executes an SQL UPDATE statement for each of the provided records, updating all non-primary-key columns with the values in the records.
// Returns [MissingRecordError] if any of the records does not exist in the database, that is, if for any of the records, the database contains no row with the same primary key values.
//
@@ -122,6 +146,10 @@ func (s Store[R]) Update(ctx context.Context, db Handle, records ...R) error {
for idx := range records {
v := reflect.ValueOf(&records[idx]).Elem()
+ err := checkTransparentPointerStructFieldsInitialized("UPDATE", idx, v, s.plan, false)
+ if err != nil {
+ return newIOError(err, "Stmt.Close", stmt.Close())
+ }
rowsAffected, err := updateRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots)
if err == nil && rowsAffected == 0 {
err = MissingRecordError[R]{records[idx], s.plan}
@@ -134,7 +162,6 @@ func (s Store[R]) Update(ctx context.Context, db Handle, records ...R) error {
}
func updateRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) (int64, error) {
- // TODO: check plan.IndexesOfTransparentPointerStructs, return error (instead of panicking on FieldByIndex) if pointer struct is nil
for idx, index := range argumentIndexes {
argumentSlots[idx] = v.FieldByIndex(index).Interface()
}
@@ -168,7 +195,7 @@ func (s Store[R]) Delete(ctx context.Context, db Handle, records ...R) error {
for idx := range records {
v := reflect.ValueOf(&records[idx]).Elem()
- err := deleteRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots)
+ err := deleteRecord(ctx, s.plan, v, idx, stmt, argumentIndexes, argumentSlots)
if err != nil {
return newIOError(err, "Stmt.Close", stmt.Close())
}
@@ -177,13 +204,15 @@ func (s Store[R]) Delete(ctx context.Context, db Handle, records ...R) error {
return newIOError(nil, "Stmt.Close", stmt.Close())
}
-func deleteRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) error {
- // TODO: consider checking plan.IndexesOfTransparentPointerStructs and returning an error (instead of panicking on FieldByIndex) if pointer struct is nil
- // (might want to have bookkeeping to only check pointer structs that lead to PK fields)
+func deleteRecord(ctx context.Context, plan plan, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any) error {
+ err := checkTransparentPointerStructFieldsInitialized("DELETE", recordIndex, v, plan, true)
+ if err != nil {
+ return newIOError(err, "Stmt.Close", stmt.Close())
+ }
for idx, index := range argumentIndexes {
argumentSlots[idx] = v.FieldByIndex(index).Interface()
}
- _, err := stmt.Exec(ctx, argumentSlots)
+ _, err = stmt.Exec(ctx, argumentSlots)
if err != nil {
return fmt.Errorf("while deleting record with idx = %d: %w", recordIndex, err)
}
@@ -217,9 +246,19 @@ func (s Store[R]) Upsert(ctx context.Context, db Handle, records ...*R) error {
}
updateStmt, err := prepare(ctx, db, s.plan.Update.Query, "Update", 0)
if err != nil {
- return err
+ return newIOError(err, "InsertStmt.Close", insertStmt.Close())
}
+ err = s.doUpsert(ctx, db, insertStmt, updateStmt, records)
+ err = newIOError(err, "InsertStmt.Close", insertStmt.Close())
+ err = newIOError(err, "UpdateStmt.Close", updateStmt.Close())
+ return err
+}
+
+func (s Store[R]) doUpsert(ctx context.Context, db Handle, insertStmt, updateStmt handle.Statement, records []*R) error {
+ // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization.
+ // Any expression that does not depend on type R should be factored out into a reusable function.
+
var (
insertArgumentIndexes = s.plan.Insert.ArgumentIndexes
insertArgumentSlots = make([]any, len(insertArgumentIndexes))
@@ -231,6 +270,10 @@ func (s Store[R]) Upsert(ctx context.Context, db Handle, records ...*R) error {
for idx, r := range records {
v := reflect.ValueOf(r).Elem()
+ err := checkTransparentPointerStructFieldsInitialized("INSERT or UPDATE", idx, v, s.plan, false)
+ if err != nil {
+ return err
+ }
isInsert, err := upsertDecideStrategy(v, idx, insertScanIndexes)
if err != nil {
return err
@@ -246,15 +289,11 @@ func (s Store[R]) Upsert(ctx context.Context, db Handle, records ...*R) error {
}
}
if err != nil {
- err = newIOError(err, "InsertStmt.Close", insertStmt.Close())
- err = newIOError(err, "UpdateStmt.Close", updateStmt.Close())
return err
}
}
- err = newIOError(err, "InsertStmt.Close", insertStmt.Close())
- err = newIOError(err, "UpdateStmt.Close", updateStmt.Close())
- return err
+ return nil
}
func upsertDecideStrategy(v reflect.Value, recordIndex int, scanIndexes [][]int) (isInsert bool, err error) {