diff options
Diffstat (limited to 'query.go')
| -rw-r--r-- | query.go | 37 |
1 files changed, 30 insertions, 7 deletions
@@ -5,6 +5,7 @@ package oblast import ( "context" + "database/sql" "fmt" "reflect" @@ -72,7 +73,7 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han if err != nil { return newIOError(err, "Stmt.Close", stmt.Close()) } - err = insertRecord(ctx, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots) + err = insertRecord(ctx, s.plan, v, idx, stmt, argumentIndexes, argumentSlots, scanIndexes, scanSlots) if err != nil { return newIOError(err, "Stmt.Close", stmt.Close()) } @@ -81,7 +82,7 @@ func (s Store[R]) insertUsing(ctx context.Context, stmt handle.Statement, db Han return newIOError(nil, "Stmt.Close", stmt.Close()) } -func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error { +func insertRecord(ctx context.Context, plan plan, v reflect.Value, recordIndex int, stmt handle.Statement, argumentIndexes [][]int, argumentSlots []any, scanIndexes [][]int, scanSlots []any) error { for idx, index := range argumentIndexes { argumentSlots[idx] = v.FieldByIndex(index).Interface() } @@ -92,16 +93,38 @@ func insertRecord(ctx context.Context, v reflect.Value, recordIndex int, stmt ha } scanSlots[idx] = f.Addr().Interface() } - var err error - if len(scanSlots) == 0 { + + var ( + result sql.Result + err error + ) + switch { + case len(scanSlots) == 0: _, err = stmt.Exec(ctx, argumentSlots) - } else { - // TODO: using QueryRow for inserting is extremely expensive because database/sql allocates a Rows instance under the hood; other libraries are doing better by limiting themselves to ExecContext() + LastInsertId() + case plan.InsertUsesQueryRow: err = stmt.QueryRow(ctx, argumentSlots, scanSlots) + default: + result, err = stmt.Exec(ctx, argumentSlots) } if err != nil { return fmt.Errorf("while inserting record with idx = %d: %w", recordIndex, err) } + + if result != nil { + id, err := result.LastInsertId() + if err != nil { + return fmt.Errorf("while getting LastInsertId for record with idx = %d: %w", recordIndex, err) + } + if plan.LastInsertIdIsUnsigned { + if id < 0 { + return fmt.Errorf("LastInsertId() = %d for record with idx = %d cannot be converted to uint", id, recordIndex) + } + v.FieldByIndex(scanIndexes[0]).SetUint(uint64(id)) + } else { + v.FieldByIndex(scanIndexes[0]).SetInt(id) + } + } + return nil } @@ -280,7 +303,7 @@ func (s Store[R]) doUpsert(ctx context.Context, db Handle, insertStmt, updateStm } if isInsert { - err = insertRecord(ctx, v, idx, insertStmt, insertArgumentIndexes, insertArgumentSlots, insertScanIndexes, insertScanSlots) + err = insertRecord(ctx, s.plan, v, idx, insertStmt, insertArgumentIndexes, insertArgumentSlots, insertScanIndexes, insertScanSlots) } else { var rowsAffected int64 rowsAffected, err = updateRecord(ctx, v, idx, updateStmt, updateArgumentIndexes, updateArgumentSlots) |
