aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/assert/assert.go13
-rw-r--r--plan.go9
-rw-r--r--select.go14
-rw-r--r--select_test.go75
4 files changed, 106 insertions, 5 deletions
diff --git a/internal/assert/assert.go b/internal/assert/assert.go
index 84b6ecf..99af59c 100644
--- a/internal/assert/assert.go
+++ b/internal/assert/assert.go
@@ -36,3 +36,16 @@ func SliceEqual[V comparable](t testing.TB, actual []V, expected ...V) {
}
}
}
+
+// SliceDeepEqual is a test assertion.
+func SliceDeepEqual[V any](t testing.TB, actual []V, expected ...V) {
+ t.Helper()
+ if len(actual) != len(expected) {
+ t.Errorf("length mismatch: expected %#v, but got %#v", expected, actual)
+ }
+ for idx := range actual {
+ if !reflect.DeepEqual(actual[idx], expected[idx]) {
+ t.Errorf("element %d: expected %#v, but got %#v", idx, expected[idx], actual[idx])
+ }
+ }
+}
diff --git a/plan.go b/plan.go
index d00b5c2..245deab 100644
--- a/plan.go
+++ b/plan.go
@@ -21,8 +21,10 @@ type plan struct {
PrimaryKeyColumnNames []string // from info.PrimaryKeyIs marker (if any)
AutoColumnNames []string // subset of AllColumnNames where field has `,auto` marker
- // Argument for reflect.Value.FieldByIndex() for each column name.
+ // Field index (i.e. argument for reflect.Value.FieldByIndex()) for each column name.
IndexByColumnName map[string][]int
+ // Indexes of pointer-typed fields that need to be initialized before scanning into this type.
+ IndexesOfTransparentPointerStructs [][]int
// In dialects with UsesLastInsertID() == true, whether the ID column must be written with reflect.Value.SetInt() or reflect.Value.SetUint().
FillIDWithSetUint bool
@@ -85,6 +87,11 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) {
if field.Type.Kind() == reflect.Struct || (field.Type.Kind() == reflect.Pointer && field.Type.Elem().Kind() == reflect.Struct) {
if field.Tag.Get("db") == "" {
indexesOfUnusedTransparentStructs = append(indexesOfUnusedTransparentStructs, field.Index)
+ if field.Type.Kind() == reflect.Pointer {
+ // remember that, when scanning into a record of type `t`, we need to write a non-nil zeroed struct into this field
+ // to enable taking an address of its mapped member fields
+ p.IndexesOfTransparentPointerStructs = append(p.IndexesOfTransparentPointerStructs, field.Index)
+ }
continue
}
indexesOfOpaqueStructs = append(indexesOfOpaqueStructs, field.Index)
diff --git a/select.go b/select.go
index 8aed249..9de6e13 100644
--- a/select.go
+++ b/select.go
@@ -29,7 +29,7 @@ func (s Store[R]) Select(db Handle, query string, args ...any) (result []R, retu
slots := make([]any, len(indexes))
for rows.Next() {
var target R
- err = collectRow(rows, reflect.ValueOf(&target).Elem(), slots, indexes)
+ err = collectRow(rows, s.plan, reflect.ValueOf(&target).Elem(), slots, indexes)
if err != nil {
return nil, err
}
@@ -67,7 +67,7 @@ func (s Store[R]) SelectWhere(db Handle, partialQuery string, args ...any) (resu
slots := make([]any, len(indexes))
for rows.Next() {
var target R
- err = collectRow(rows, reflect.ValueOf(&target).Elem(), slots, indexes)
+ err = collectRow(rows, s.plan, reflect.ValueOf(&target).Elem(), slots, indexes)
if err != nil {
return nil, err
}
@@ -119,7 +119,11 @@ func startSelectWhereQuery(db Handle, plan plan, partialQuery string, args ...an
return rows, plan.Select.ScanIndexes, err
}
-func collectRow(rows *sql.Rows, v reflect.Value, slots []any, indexes [][]int) error {
+func collectRow(rows *sql.Rows, plan plan, v reflect.Value, slots []any, indexes [][]int) error {
+ for _, index := range plan.IndexesOfTransparentPointerStructs {
+ f := v.FieldByIndex(index)
+ f.Set(reflect.New(f.Type().Elem()))
+ }
for idx, index := range indexes {
slots[idx] = v.FieldByIndex(index).Addr().Interface()
}
@@ -178,6 +182,10 @@ func selectOneWhere(db Handle, plan plan, v reflect.Value, partialQuery string,
return errors.New("cannot execute SelectOneWhere() because query could not be autogenerated")
}
query := plan.Select.Query + partialQuery
+ for _, index := range plan.IndexesOfTransparentPointerStructs {
+ f := v.FieldByIndex(index)
+ f.Set(reflect.New(f.Type().Elem()))
+ }
slots := make([]any, len(plan.Select.ScanIndexes))
for idx, index := range plan.Select.ScanIndexes {
slots[idx] = v.FieldByIndex(index).Addr().Interface()
diff --git a/select_test.go b/select_test.go
index c3285be..9fcecc3 100644
--- a/select_test.go
+++ b/select_test.go
@@ -218,6 +218,79 @@ func TestSelectWithScanError(t *testing.T) {
})
}
+func TestSelectIntoEmbeddedTypes(t *testing.T) {
+ md := mock.NewDriver()
+ db := sql.OpenDB(md)
+
+ type HasCreatedAt struct {
+ CreatedAt time.Time `db:"created_at"`
+ }
+ type HasUpdatedAt struct {
+ UpdatedAt *time.Time `db:"updated_at"`
+ }
+ type compositeRecord struct {
+ ID int64 `db:"id"`
+ HasCreatedAt
+ // This test specifically wants to see that this field gets initialized
+ // whenever one of the Store.Select methods creates a compositeRecord instance.
+ *HasUpdatedAt
+ }
+ store := oblast.MustNewStore[compositeRecord](
+ oblast.SqliteDialect(),
+ oblast.TableNameIs("composite_records"),
+ oblast.PrimaryKeyIs("id"),
+ )
+
+ t.Run("using Store.Select", func(t *testing.T) {
+ md.ForQuery(`SELECT * FROM composite_records`).
+ ExpectQueryWithArgs(nil...).
+ AndReturnColumns("id", "created_at", "updated_at").
+ WithRow(1, time.Unix(1, 0), time.Unix(3, 0)).
+ WithRow(2, time.Unix(2, 0), nil)
+ records := must.Return(store.Select(db, `SELECT * FROM composite_records`))(t)
+ assert.SliceDeepEqual(t, records,
+ compositeRecord{1, HasCreatedAt{time.Unix(1, 0)}, &HasUpdatedAt{new(time.Unix(3, 0))}},
+ compositeRecord{2, HasCreatedAt{time.Unix(2, 0)}, &HasUpdatedAt{nil}},
+ )
+ })
+
+ t.Run("using Store.SelectWhere", func(t *testing.T) {
+ md.ForQuery(`SELECT "id", "created_at", "updated_at" FROM "composite_records" WHERE TRUE`).
+ ExpectQueryWithArgs(nil...).
+ AndReturnColumns("id", "created_at", "updated_at").
+ WithRow(1, time.Unix(1, 0), time.Unix(3, 0)).
+ WithRow(2, time.Unix(2, 0), nil)
+ records := must.Return(store.SelectWhere(db, `TRUE`))(t)
+ assert.SliceDeepEqual(t, records,
+ compositeRecord{1, HasCreatedAt{time.Unix(1, 0)}, &HasUpdatedAt{new(time.Unix(3, 0))}},
+ compositeRecord{2, HasCreatedAt{time.Unix(2, 0)}, &HasUpdatedAt{nil}},
+ )
+ })
+
+ t.Run("using Store.SelectOne", func(t *testing.T) {
+ md.ForQuery(`SELECT * FROM composite_records`).
+ ExpectQueryWithArgs(nil...).
+ AndReturnColumns("id", "created_at", "updated_at").
+ WithRow(1, time.Unix(1, 0), time.Unix(3, 0)).
+ WithRow(2, time.Unix(2, 0), nil)
+ record := must.Return(store.SelectOne(db, `SELECT * FROM composite_records`))(t)
+ assert.DeepEqual(t, record,
+ compositeRecord{1, HasCreatedAt{time.Unix(1, 0)}, &HasUpdatedAt{new(time.Unix(3, 0))}},
+ )
+ })
+
+ t.Run("using Store.SelectOneWhere", func(t *testing.T) {
+ md.ForQuery(`SELECT "id", "created_at", "updated_at" FROM "composite_records" WHERE TRUE`).
+ ExpectQueryWithArgs(nil...).
+ AndReturnColumns("id", "created_at", "updated_at").
+ WithRow(1, time.Unix(1, 0), time.Unix(3, 0)).
+ WithRow(2, time.Unix(2, 0), nil)
+ record := must.Return(store.SelectOneWhere(db, `TRUE`))(t)
+ assert.DeepEqual(t, record,
+ compositeRecord{1, HasCreatedAt{time.Unix(1, 0)}, &HasUpdatedAt{new(time.Unix(3, 0))}},
+ )
+ })
+}
+
// TODO: test error capture during Rows.Close()
// TODO: check for maximum test coverage in select.go
-// TODO: test that, during Select(), assignment into embedded fields with pointer-to-struct type works (docs say that this might panic if we do not allocate into the pointer first)