From 9b5b72a549643a9e611f55ae8154fa801c808e5b Mon Sep 17 00:00:00 2001 From: Stefan Majewsky Date: Sun, 12 Apr 2026 17:59:16 +0200 Subject: add Store.SelectWhere, Store.SelectOneWhere --- benchmark/benchmark_test.go | 52 ++++++++++++++++++++++++------ internal/plan.go | 24 ++++++++++++++ internal/plan_test.go | 16 ++++++++++ select.go | 77 +++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 158 insertions(+), 11 deletions(-) diff --git a/benchmark/benchmark_test.go b/benchmark/benchmark_test.go index b60951a..edebc52 100644 --- a/benchmark/benchmark_test.go +++ b/benchmark/benchmark_test.go @@ -48,7 +48,7 @@ func makeTestDB(t testing.TB) (*sql.DB, error) { return db, nil } -func BenchmarkSelect(b *testing.B) { +func BenchmarkSelectMany(b *testing.B) { db, err := makeTestDB(b) if err != nil { b.Fatal(err) @@ -63,12 +63,16 @@ func BenchmarkSelect(b *testing.B) { ID int `db:"id"` Message string `db:"message"` } - store, err := oblast.NewStore[record](oblast.SqliteDialect()) + store, err := oblast.NewStore[record]( + oblast.SqliteDialect(), + oblast.TableNameIs("entries"), + ) if err != nil { b.Fatal(err) } gdb := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}} - query := `SELECT * FROM entries WHERE id < ` + strconv.Itoa(selectedRecordCount) //nolint:gosec + partialQuery := `id < ` + strconv.Itoa(selectedRecordCount) + query := `SELECT * FROM entries WHERE ` + partialQuery //nolint:gosec selectWithOblast := func(b *testing.B) { records, err := store.Select(db, query) @@ -78,6 +82,14 @@ func BenchmarkSelect(b *testing.B) { assert.Equal(b, len(records), selectedRecordCount) } + selectWithOblastWhere := func(b *testing.B) { + records, err := store.SelectWhere(db, partialQuery) + if err != nil { + b.Error(err) + } + assert.Equal(b, len(records), selectedRecordCount) + } + selectWithGorp := func(b *testing.B) { var records []record _, err := gdb.Select(&records, query) @@ -121,16 +133,21 @@ func BenchmarkSelect(b *testing.B) { } // run actual benchmark - b.Run("via Gorp", func(b *testing.B) { + b.Run("via Gorp using Select", func(b *testing.B) { for range b.N { selectWithGorp(b) } }) - b.Run("via Oblast", func(b *testing.B) { + b.Run("via Oblast using Select", func(b *testing.B) { for range b.N { selectWithOblast(b) } }) + b.Run("via Oblast using SelectWhere", func(b *testing.B) { + for range b.N { + selectWithOblastWhere(b) + } + }) b.Run("just SQLite", func(b *testing.B) { for range b.N { selectWithSqlite(b) @@ -154,12 +171,16 @@ func BenchmarkSelectOne(b *testing.B) { ID int `db:"id"` Message string `db:"message"` } - store, err := oblast.NewStore[record](oblast.SqliteDialect()) + store, err := oblast.NewStore[record]( + oblast.SqliteDialect(), + oblast.TableNameIs("entries"), + ) if err != nil { b.Fatal(err) } gdb := gorp.DbMap{Db: db, Dialect: gorp.SqliteDialect{}} - query := `SELECT * FROM entries WHERE id = ` + strconv.Itoa(recordID) + partialQuery := `id = ` + strconv.Itoa(recordID) + query := `SELECT * FROM entries WHERE ` + partialQuery selectWithOblast := func(b *testing.B) { r, err := store.SelectOne(db, query) @@ -169,6 +190,14 @@ func BenchmarkSelectOne(b *testing.B) { assert.Equal(b, r.ID, recordID) } + selectWithOblastWhere := func(b *testing.B) { + r, err := store.SelectOneWhere(db, partialQuery) + if err != nil { + b.Error(err) + } + assert.Equal(b, r.ID, recordID) + } + selectWithGorp := func(b *testing.B) { var r record err := gdb.SelectOne(&r, query) @@ -198,16 +227,21 @@ func BenchmarkSelectOne(b *testing.B) { } // run actual benchmark - b.Run("via Gorp", func(b *testing.B) { + b.Run("via Gorp using SelectOne", func(b *testing.B) { for range b.N { selectWithGorp(b) } }) - b.Run("via Oblast", func(b *testing.B) { + b.Run("via Oblast using SelectOne", func(b *testing.B) { for range b.N { selectWithOblast(b) } }) + b.Run("via Oblast using SelectOneWhere", func(b *testing.B) { + for range b.N { + selectWithOblastWhere(b) + } + }) b.Run("just SQLite", func(b *testing.B) { for range b.N { selectWithSqlite(b) diff --git a/internal/plan.go b/internal/plan.go index ac199bf..7dc3361 100644 --- a/internal/plan.go +++ b/internal/plan.go @@ -23,6 +23,7 @@ type Plan struct { IndexByColumnName map[string][]int // Planned queries. + Select PlannedQuery // only `SELECT ... FROM ...` without WHERE or any of the other clauses Insert PlannedQuery Update PlannedQuery Delete PlannedQuery @@ -126,6 +127,7 @@ func buildPlan(t reflect.Type, dialect Dialect, opts PlanOpts) (Plan, error) { } // prepare query strings + p.Select = p.buildSelectQueryIfPossible(dialect) p.Insert = p.buildInsertQueryIfPossible(dialect) p.Update = p.buildUpdateQueryIfPossible(dialect) p.Delete = p.buildDeleteQueryIfPossible(dialect) @@ -153,6 +155,28 @@ func (p Plan) getNonPrimaryKeyColumnNames() []string { return result } +func (p Plan) buildSelectQueryIfPossible(dialect Dialect) PlannedQuery { + if p.TableName == "" { + return PlannedQuery{Query: ""} + } + + var ( + argumentIndexes = make([][]int, len(p.AllColumnNames)) + quotedColumnNames = make([]string, len(p.AllColumnNames)) + ) + for idx, columnName := range p.AllColumnNames { + argumentIndexes[idx] = p.IndexByColumnName[columnName] + quotedColumnNames[idx] = dialect.QuoteIdentifier(columnName) + } + + query := fmt.Sprintf( + `SELECT %s FROM %s WHERE `, + strings.Join(quotedColumnNames, ", "), + dialect.QuoteIdentifier(p.TableName), + ) + return PlannedQuery{query, argumentIndexes} +} + func (p Plan) buildInsertQueryIfPossible(dialect Dialect) PlannedQuery { if p.TableName == "" || len(p.AllColumnNames) == 0 { return PlannedQuery{Query: ""} diff --git a/internal/plan_test.go b/internal/plan_test.go index 88afedc..db12943 100644 --- a/internal/plan_test.go +++ b/internal/plan_test.go @@ -75,6 +75,8 @@ func TestQueryConstructionBasic(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES ($1, $2) RETURNING "ID"`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = $1, "CreatedAt" = $2 WHERE "ID" = $3`) @@ -88,6 +90,8 @@ func TestQueryConstructionBasic(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.Select.Query, `SELECT "ID", "Description", "CreatedAt" FROM "basic_records" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "basic_records" ("Description", "CreatedAt") VALUES (?, ?)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}, {2}}) assert.Equal(t, plan.Update.Query, `UPDATE "basic_records" SET "Description" = ?, "CreatedAt" = ? WHERE "ID" = ?`) @@ -111,6 +115,8 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES ($1, $2)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) assert.Equal(t, plan.Update.Query, "") @@ -124,6 +130,8 @@ func TestQueryConstructionWithoutPrimaryKey(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.Select.Query, `SELECT "foo_id", "bar_id" FROM "foo_bar_relations" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "foo_bar_relations" ("foo_id", "bar_id") VALUES (?, ?)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}}) assert.Equal(t, plan.Update.Query, "") @@ -147,6 +155,8 @@ func TestQueryConstructionImpossble(t *testing.T) { t.Error(err) } + assert.Equal(t, plan.Select.Query, "") + assert.DeepEqual(t, plan.Select.ArgumentIndexes, nil) assert.Equal(t, plan.Insert.Query, "") assert.DeepEqual(t, plan.Insert.ArgumentIndexes, nil) assert.Equal(t, plan.Update.Query, "") @@ -176,6 +186,8 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES ($1, $2, $3)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = $1 WHERE "group_id" = $2 AND "name" = $3`) @@ -189,6 +201,8 @@ func TestQueryConstructionWithMultiplePrimaryKeyColumns(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.Select.Query, `SELECT "group_id", "name", "created_at" FROM "complex_records" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "complex_records" ("group_id", "name", "created_at") VALUES (?, ?, ?)`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{0}, {1}, {2}}) assert.Equal(t, plan.Update.Query, `UPDATE "complex_records" SET "created_at" = ? WHERE "group_id" = ? AND "name" = ?`) @@ -214,6 +228,8 @@ func TestQueryConstructionWithMultipleAutoColumns(t *testing.T) { if err != nil { t.Error(err) } + assert.Equal(t, plan.Select.Query, `SELECT "id", "name", "created_at" FROM "autogenerated_records" WHERE `) + assert.DeepEqual(t, plan.Select.ArgumentIndexes, [][]int{{0}, {1}, {2}}) assert.Equal(t, plan.Insert.Query, `INSERT INTO "autogenerated_records" ("name") VALUES ($1) RETURNING "id", "created_at"`) assert.DeepEqual(t, plan.Insert.ArgumentIndexes, [][]int{{1}}) assert.Equal(t, plan.Update.Query, `UPDATE "autogenerated_records" SET "name" = $1, "created_at" = $2 WHERE "id" = $3`) diff --git a/select.go b/select.go index 23521ed..6616434 100644 --- a/select.go +++ b/select.go @@ -5,6 +5,7 @@ package oblast import ( "database/sql" + "errors" "fmt" "reflect" @@ -19,7 +20,7 @@ func (s Store[R]) Select(db Handle, query string, args ...any) (result []R, retu // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. // Any expression that does not depend on type R should be factored out into a reusable function. - rows, indexes, err := startQuery(db, s.plan, query, args...) + rows, indexes, err := startSelectQuery(db, s.plan, query, args...) if err != nil { return nil, err } @@ -40,7 +41,45 @@ func (s Store[R]) Select(db Handle, query string, args ...any) (result []R, retu return result, nil } -func startQuery(db Handle, plan internal.Plan, query string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { +// SelectWhere is like [Store.Select], but you only provide the part of the SELECT query that comes after the WHERE. +// The initial part ("SELECT ... FROM ... WHERE") is autogenerated and prepended to partialQuery. +// This has two benefits: +// - It is more efficient because the strategy for loading result rows into the record type R has already been precomputed during [NewStore], +// whereas a regular [Store.Select] must inspect the column names in the result set for each [Store.Select] call. +// - For record types that contain only some of the columns of the corresponding database table, +// the autogenerated SELECT query will only load exactly the necessary fields and nothing else. +// +// partialQuery is implied to start right after the WHERE keyword, which is added automatically. +// To select all records unconditionally, provide a partialQuery of "TRUE", leading to a full query of "SELECT ... FROM ... WHERE TRUE". +// Besides a condition for the WHERE clause, it may contain additional clauses, such as ORDER BY or LIMIT. +// +// Returns an error if [NewStore] was called without the [TableNameIs] option, which is required to generate a query for this method. +func (s Store[R]) SelectWhere(db Handle, partialQuery string, args ...any) (result []R, returnedError error) { + // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. + // Any expression that does not depend on type R should be factored out into a reusable function. + + rows, indexes, err := startSelectWhereQuery(db, s.plan, partialQuery, args...) + if err != nil { + return nil, err + } + defer func() { + returnedError = mergeRowsCloseError(returnedError, rows.Close()) + }() + + slots := make([]any, len(indexes)) + for rows.Next() { + var target R + err = collectRow(rows, reflect.ValueOf(&target).Elem(), slots, indexes) + if err != nil { + return nil, err + } + result = append(result, target) + } + + return result, nil +} + +func startSelectQuery(db Handle, plan internal.Plan, query string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { rows, err = db.Query(query, args...) if err != nil { return nil, nil, fmt.Errorf("during Query(): %w", err) @@ -73,6 +112,15 @@ func startQuery(db Handle, plan internal.Plan, query string, args ...any) (rows return rows, indexes, nil } +func startSelectWhereQuery(db Handle, plan internal.Plan, partialQuery string, args ...any) (rows *sql.Rows, indexes [][]int, err error) { + if plan.Select.Query == "" { + return nil, nil, errors.New("cannot execute SelectWhere() because SELECT query could not be autogenerated") + } + query := plan.Select.Query + partialQuery + rows, err = db.Query(query, args...) + return rows, plan.Select.ArgumentIndexes, err +} + func collectRow(rows *sql.Rows, v reflect.Value, slots []any, indexes [][]int) error { for idx, index := range indexes { slots[idx] = v.FieldByIndex(index).Addr().Interface() @@ -106,6 +154,7 @@ func mergeRowsCloseError(err, closeErr error) error { func (s Store[R]) SelectOne(db Handle, query string, args ...any) (result R, err error) { // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. // Any expression that does not depend on type R should be factored out into a reusable function. + var results []R results, err = s.Select(db, query, args...) if err == nil { @@ -120,3 +169,27 @@ func (s Store[R]) SelectOne(db Handle, query string, args ...any) (result R, err } return } + +// SelectOneWhere is like [Store.SelectOne], but you only provide the part of the SELECT query that comes after the WHERE. +// See [Store.SelectWhere] for an explanation of how the full query is constructed from this partial query. +// +// This method is significantly more efficient than [Store.SelectWhere] and using it is recommended when possible. +func (s Store[R]) SelectOneWhere(db Handle, partialQuery string, args ...any) (result R, err error) { + // NOTE: This function body should be as short as possible to reduce the binary size after monomorphization. + // Any expression that does not depend on type R should be factored out into a reusable function. + + err = selectOneWhere(db, s.plan, reflect.ValueOf(&result).Elem(), partialQuery, args) + return +} + +func selectOneWhere(db Handle, plan internal.Plan, v reflect.Value, partialQuery string, args []any) error { + if plan.Select.Query == "" { + return errors.New("cannot execute SelectOneWhere() because SELECT query could not be autogenerated") + } + query := plan.Select.Query + partialQuery + slots := make([]any, len(plan.Select.ArgumentIndexes)) + for idx, index := range plan.Select.ArgumentIndexes { + slots[idx] = v.FieldByIndex(index).Addr().Interface() + } + return db.QueryRow(query, args...).Scan(slots...) +} -- cgit v1.2.3