aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/mock/mock.go12
-rw-r--r--query.go9
-rw-r--r--query_test.go137
3 files changed, 155 insertions, 3 deletions
diff --git a/internal/mock/mock.go b/internal/mock/mock.go
index ecbb03e..6265166 100644
--- a/internal/mock/mock.go
+++ b/internal/mock/mock.go
@@ -158,7 +158,17 @@ func (s *statement) Close() error {
// NumInput implements the [driver.Stmt] interface.
func (s *statement) NumInput() int {
- return strings.Count(s.query, "?") // NOTE: extremely crude, but does the job for us
+ // option 1: when using SQLite dialect, count `?`
+ count := strings.Count(s.query, "?")
+ if count > 0 {
+ return count
+ }
+
+ // option 2: when using PostgreSQL dialect, find `$1`, `$2`, etc.
+ for strings.Contains(s.query, fmt.Sprintf("$%d", count+1)) {
+ count++
+ }
+ return count
}
// Exec implements the [driver.Stmt] interface.
diff --git a/query.go b/query.go
index 1b9ae46..7810bac 100644
--- a/query.go
+++ b/query.go
@@ -63,7 +63,7 @@ func (s Store[R]) Insert(db Handle, records ...R) (returnedRecords []R, returned
for idx := range records {
v := reflect.ValueOf(&records[idx]).Elem()
for idx, index := range argumentIndexes {
- argumentSlots[idx] = v.FieldByIndex(index).Addr().Interface()
+ argumentSlots[idx] = v.FieldByIndex(index).Interface()
}
if s.dialect.UsesLastInsertID() {
@@ -95,7 +95,12 @@ func (s Store[R]) Insert(db Handle, records ...R) (returnedRecords []R, returned
for idx, index := range scanIndexes {
scanSlots[idx] = v.FieldByIndex(index).Addr().Interface()
}
- err := stmt.QueryRow(argumentSlots...).Scan(scanSlots...)
+ var err error
+ if stmt == nil {
+ err = db.QueryRow(s.plan.Insert.Query, argumentSlots...).Scan(scanSlots...)
+ } else {
+ err = stmt.QueryRow(argumentSlots...).Scan(scanSlots...)
+ }
if err != nil {
return nil, fmt.Errorf("during QueryRow() for record with idx = %d: %w", idx, err)
}
diff --git a/query_test.go b/query_test.go
new file mode 100644
index 0000000..ec1d8d8
--- /dev/null
+++ b/query_test.go
@@ -0,0 +1,137 @@
+// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net>
+// SPDX-License-Identifier: Apache-2.0
+
+package oblast_test
+
+import (
+ "database/sql"
+ "strconv"
+ "testing"
+
+ "go.xyrillian.de/oblast"
+ "go.xyrillian.de/oblast/internal/assert"
+ "go.xyrillian.de/oblast/internal/mock"
+ "go.xyrillian.de/oblast/internal/must"
+)
+
+func TestInsertBasicUsingLastInsertId(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID int64 `db:"id,auto"`
+ Name string `db:"name"`
+ }
+ store := oblast.MustNewStore[basicRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} {
+ t.Run("N="+strconv.Itoa(batchSize), func(t *testing.T) {
+ records := make([]basicRecord, batchSize)
+ for idx := range batchSize {
+ records[idx] = basicRecord{Name: "new"}
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES (?)`).
+ ExpectExecWithArgs("new").
+ AndReturnLastInsertId(int64(42 + idx)).
+ AndReturnRowsAffected(1)
+ }
+ records = must.Return(store.Insert(db, records...))(t)
+ for idx, r := range records {
+ assert.Equal(t, r.ID, int64(42+idx))
+ }
+ })
+ }
+}
+
+func TestInsertBasicUsingReturningClause(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID int64 `db:"id,auto"`
+ Name string `db:"name"`
+ }
+ store := oblast.MustNewStore[basicRecord](
+ oblast.PostgresDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} {
+ t.Run("N="+strconv.Itoa(batchSize), func(t *testing.T) {
+ records := make([]basicRecord, batchSize)
+ for idx := range batchSize {
+ records[idx] = basicRecord{Name: "new"}
+ md.ForQuery(`INSERT INTO "basic_records" ("name") VALUES ($1) RETURNING "id"`).
+ ExpectQueryWithArgs("new").
+ AndReturnColumns("id").
+ WithRow(int64(42 + idx))
+ }
+ records = must.Return(store.Insert(db, records...))(t)
+ for idx, r := range records {
+ assert.Equal(t, r.ID, int64(42+idx))
+ }
+ })
+ }
+}
+
+func TestUpdateBasic(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID int64 `db:"id,auto"`
+ Name string `db:"name"`
+ }
+ store := oblast.MustNewStore[basicRecord](
+ oblast.PostgresDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} {
+ t.Run("N="+strconv.Itoa(batchSize), func(t *testing.T) {
+ records := make([]basicRecord, batchSize)
+ for idx := range batchSize {
+ r := basicRecord{ID: int64(42 + idx), Name: "updated"}
+ records[idx] = r
+ md.ForQuery(`UPDATE "basic_records" SET "name" = $1 WHERE "id" = $2`).
+ ExpectExecWithArgs(r.Name, r.ID).
+ AndReturnRowsAffected(1)
+ }
+ must.Succeed(t, store.Update(db, records...))
+ })
+ }
+}
+
+func TestDeleteBasic(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type basicRecord struct {
+ ID int64 `db:"id,auto"`
+ Name string `db:"name"`
+ }
+ store := oblast.MustNewStore[basicRecord](
+ oblast.PostgresDialect(),
+ oblast.TableNameIs("basic_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ for _, batchSize := range []int{1, oblast.PrepareThreshold - 1, oblast.PrepareThreshold + 1} {
+ t.Run("N="+strconv.Itoa(batchSize), func(t *testing.T) {
+ records := make([]basicRecord, batchSize)
+ for idx := range batchSize {
+ r := basicRecord{ID: int64(42 + idx), Name: "removed"}
+ records[idx] = r
+ md.ForQuery(`DELETE FROM "basic_records" WHERE "id" = $1`).
+ ExpectExecWithArgs(r.ID).
+ AndReturnRowsAffected(1)
+ }
+ must.Succeed(t, store.Delete(db, records...))
+ })
+ }
+}