diff options
| -rw-r--r-- | internal/assert/assert.go | 13 | ||||
| -rw-r--r-- | plan.go | 9 | ||||
| -rw-r--r-- | select.go | 14 | ||||
| -rw-r--r-- | select_test.go | 75 |
4 files changed, 106 insertions, 5 deletions
diff --git a/internal/assert/assert.go b/internal/assert/assert.go index 84b6ecf..99af59c 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -36,3 +36,16 @@ func SliceEqual[V comparable](t testing.TB, actual []V, expected ...V) { } } } + +// SliceDeepEqual is a test assertion. +func SliceDeepEqual[V any](t testing.TB, actual []V, expected ...V) { + t.Helper() + if len(actual) != len(expected) { + t.Errorf("length mismatch: expected %#v, but got %#v", expected, actual) + } + for idx := range actual { + if !reflect.DeepEqual(actual[idx], expected[idx]) { + t.Errorf("element %d: expected %#v, but got %#v", idx, expected[idx], actual[idx]) + } + } +} @@ -21,8 +21,10 @@ type plan struct { PrimaryKeyColumnNames []string // from info.PrimaryKeyIs marker (if any) AutoColumnNames []string // subset of AllColumnNames where field has `,auto` marker - // Argument for reflect.Value.FieldByIndex() for each column name. + // Field index (i.e. argument for reflect.Value.FieldByIndex()) for each column name. IndexByColumnName map[string][]int + // Indexes of pointer-typed fields that need to be initialized before scanning into this type. + IndexesOfTransparentPointerStructs [][]int // In dialects with UsesLastInsertID() == true, whether the ID column must be written with reflect.Value.SetInt() or reflect.Value.SetUint(). FillIDWithSetUint bool @@ -85,6 +87,11 @@ func buildPlan(t reflect.Type, dialect Dialect, opts planOpts) (plan, error) { if field.Type.Kind() == reflect.Struct || (field.Type.Kind() == reflect.Pointer && field.Type.Elem().Kind() == reflect.Struct) { if field.Tag.Get("db") == "" { indexesOfUnusedTransparentStructs = append(indexesOfUnusedTransparentStructs, field.Index) + if field.Type.Kind() == reflect.Pointer { + // remember that, when scanning into a record of type `t`, we need to write a non-nil zeroed struct into this field + // to enable taking an address of its mapped member fields + p.IndexesOfTransparentPointerStructs = append(p.IndexesOfTransparentPointerStructs, field.Index) + } continue } indexesOfOpaqueStructs = append(indexesOfOpaqueStructs, field.Index) @@ -29,7 +29,7 @@ func (s Store[R]) Select(db Handle, query string, args ...any) (result []R, retu slots := make([]any, len(indexes)) for rows.Next() { var target R - err = collectRow(rows, reflect.ValueOf(&target).Elem(), slots, indexes) + err = collectRow(rows, s.plan, reflect.ValueOf(&target).Elem(), slots, indexes) if err != nil { return nil, err } @@ -67,7 +67,7 @@ func (s Store[R]) SelectWhere(db Handle, partialQuery string, args ...any) (resu slots := make([]any, len(indexes)) for rows.Next() { var target R - err = collectRow(rows, reflect.ValueOf(&target).Elem(), slots, indexes) + err = collectRow(rows, s.plan, reflect.ValueOf(&target).Elem(), slots, indexes) if err != nil { return nil, err } @@ -119,7 +119,11 @@ func startSelectWhereQuery(db Handle, plan plan, partialQuery string, args ...an return rows, plan.Select.ScanIndexes, err } -func collectRow(rows *sql.Rows, v reflect.Value, slots []any, indexes [][]int) error { +func collectRow(rows *sql.Rows, plan plan, v reflect.Value, slots []any, indexes [][]int) error { + for _, index := range plan.IndexesOfTransparentPointerStructs { + f := v.FieldByIndex(index) + f.Set(reflect.New(f.Type().Elem())) + } for idx, index := range indexes { slots[idx] = v.FieldByIndex(index).Addr().Interface() } @@ -178,6 +182,10 @@ func selectOneWhere(db Handle, plan plan, v reflect.Value, partialQuery string, return errors.New("cannot execute SelectOneWhere() because query could not be autogenerated") } query := plan.Select.Query + partialQuery + for _, index := range plan.IndexesOfTransparentPointerStructs { + f := v.FieldByIndex(index) + f.Set(reflect.New(f.Type().Elem())) + } slots := make([]any, len(plan.Select.ScanIndexes)) for idx, index := range plan.Select.ScanIndexes { slots[idx] = v.FieldByIndex(index).Addr().Interface() diff --git a/select_test.go b/select_test.go index c3285be..9fcecc3 100644 --- a/select_test.go +++ b/select_test.go @@ -218,6 +218,79 @@ func TestSelectWithScanError(t *testing.T) { }) } +func TestSelectIntoEmbeddedTypes(t *testing.T) { + md := mock.NewDriver() + db := sql.OpenDB(md) + + type HasCreatedAt struct { + CreatedAt time.Time `db:"created_at"` + } + type HasUpdatedAt struct { + UpdatedAt *time.Time `db:"updated_at"` + } + type compositeRecord struct { + ID int64 `db:"id"` + HasCreatedAt + // This test specifically wants to see that this field gets initialized + // whenever one of the Store.Select methods creates a compositeRecord instance. + *HasUpdatedAt + } + store := oblast.MustNewStore[compositeRecord]( + oblast.SqliteDialect(), + oblast.TableNameIs("composite_records"), + oblast.PrimaryKeyIs("id"), + ) + + t.Run("using Store.Select", func(t *testing.T) { + md.ForQuery(`SELECT * FROM composite_records`). + ExpectQueryWithArgs(nil...). + AndReturnColumns("id", "created_at", "updated_at"). + WithRow(1, time.Unix(1, 0), time.Unix(3, 0)). + WithRow(2, time.Unix(2, 0), nil) + records := must.Return(store.Select(db, `SELECT * FROM composite_records`))(t) + assert.SliceDeepEqual(t, records, + compositeRecord{1, HasCreatedAt{time.Unix(1, 0)}, &HasUpdatedAt{new(time.Unix(3, 0))}}, + compositeRecord{2, HasCreatedAt{time.Unix(2, 0)}, &HasUpdatedAt{nil}}, + ) + }) + + t.Run("using Store.SelectWhere", func(t *testing.T) { + md.ForQuery(`SELECT "id", "created_at", "updated_at" FROM "composite_records" WHERE TRUE`). + ExpectQueryWithArgs(nil...). + AndReturnColumns("id", "created_at", "updated_at"). + WithRow(1, time.Unix(1, 0), time.Unix(3, 0)). + WithRow(2, time.Unix(2, 0), nil) + records := must.Return(store.SelectWhere(db, `TRUE`))(t) + assert.SliceDeepEqual(t, records, + compositeRecord{1, HasCreatedAt{time.Unix(1, 0)}, &HasUpdatedAt{new(time.Unix(3, 0))}}, + compositeRecord{2, HasCreatedAt{time.Unix(2, 0)}, &HasUpdatedAt{nil}}, + ) + }) + + t.Run("using Store.SelectOne", func(t *testing.T) { + md.ForQuery(`SELECT * FROM composite_records`). + ExpectQueryWithArgs(nil...). + AndReturnColumns("id", "created_at", "updated_at"). + WithRow(1, time.Unix(1, 0), time.Unix(3, 0)). + WithRow(2, time.Unix(2, 0), nil) + record := must.Return(store.SelectOne(db, `SELECT * FROM composite_records`))(t) + assert.DeepEqual(t, record, + compositeRecord{1, HasCreatedAt{time.Unix(1, 0)}, &HasUpdatedAt{new(time.Unix(3, 0))}}, + ) + }) + + t.Run("using Store.SelectOneWhere", func(t *testing.T) { + md.ForQuery(`SELECT "id", "created_at", "updated_at" FROM "composite_records" WHERE TRUE`). + ExpectQueryWithArgs(nil...). + AndReturnColumns("id", "created_at", "updated_at"). + WithRow(1, time.Unix(1, 0), time.Unix(3, 0)). + WithRow(2, time.Unix(2, 0), nil) + record := must.Return(store.SelectOneWhere(db, `TRUE`))(t) + assert.DeepEqual(t, record, + compositeRecord{1, HasCreatedAt{time.Unix(1, 0)}, &HasUpdatedAt{new(time.Unix(3, 0))}}, + ) + }) +} + // TODO: test error capture during Rows.Close() // TODO: check for maximum test coverage in select.go -// TODO: test that, during Select(), assignment into embedded fields with pointer-to-struct type works (docs say that this might panic if we do not allocate into the pointer first) |
