aboutsummaryrefslogtreecommitdiff
path: root/handle.go
diff options
context:
space:
mode:
Diffstat (limited to 'handle.go')
-rw-r--r--handle.go57
1 files changed, 35 insertions, 22 deletions
diff --git a/handle.go b/handle.go
index b7f8608..bb251bb 100644
--- a/handle.go
+++ b/handle.go
@@ -16,9 +16,25 @@ import (
// Custom implementations of this interface can be used to connect non-std database drivers to Oblast.
type Handle = handle.Handle
-// StdHandle is an interface covered by both [*sql.DB] and [*sql.Tx].
+// SqlHandle wraps types like [*sql.DB] or [*sql.Tx] into a [Handle] that can be used with Oblast.
+type SqlHandle[T SqlExecutor] struct {
+ // The original database or transaction handle.
+ // It is safe to read this field to execute operations that Oblast does not handle (e.g. transactions, savepoints or OLAP queries).
+ Base T
+
+ // If this is not true, then any methods on this type will panic.
+ // This is just to enforce that the handle is constructed with Wrap(), thus guaranteeing future compatibility if actual important private struct fields are added later.
+ ok bool
+}
+
+// Wrap converts an [*sql.DB] or [*sql.Tx] into a [Handle] that can be used with Oblast functions.
+func Wrap[T SqlExecutor](dbOrTx T) SqlHandle[T] {
+ return SqlHandle[T]{Base: dbOrTx, ok: true}
+}
+
+// SqlExecutor is an interface covered by both [*sql.DB] and [*sql.Tx].
// It appears in the signature of function [Wrap].
-type StdHandle interface {
+type SqlExecutor interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
@@ -27,38 +43,35 @@ type StdHandle interface {
// static assertion that the respective types implement the interface
var (
- _ StdHandle = &sql.DB{}
- _ StdHandle = &sql.Tx{}
+ _ SqlExecutor = &sql.DB{}
+ _ SqlExecutor = &sql.Tx{}
)
-// Wrap converts an [*sql.DB] or [*sql.Tx] into a [Handle] that can be used with Oblast functions.
-func Wrap(dbOrTx StdHandle) Handle {
- return wrappedHandle{dbOrTx}
-}
-
-type wrappedHandle struct {
- db StdHandle
-}
-
-// Prepare implements the [Handle] interface.
-func (h wrappedHandle) Prepare(ctx context.Context, query string, repeated bool) (handle.Statement, error) {
+// OblastPrepare implements the [Handle] interface.
+func (h SqlHandle[T]) OblastPrepare(ctx context.Context, query string, repeated bool) (handle.Statement, error) {
+ if !h.ok {
+ panic("SqlHandle was not constructed through oblast.Wrap()!")
+ }
if !repeated {
- return wrappedStatement{h.db, query, nil}, nil
+ return wrappedStatement{h.Base, query, nil}, nil
}
- stmt, err := h.db.PrepareContext(ctx, query)
+ stmt, err := h.Base.PrepareContext(ctx, query)
if err != nil {
return nil, fmt.Errorf("during Prepare(): %w", err)
}
- return wrappedStatement{h.db, query, stmt}, nil
+ return wrappedStatement{h.Base, query, stmt}, nil
}
-// Query implements the [Handle] interface.
-func (h wrappedHandle) Query(ctx context.Context, query string, args []any) (handle.Rows, error) {
- return h.db.QueryContext(ctx, query, args...) //nolint:rowserrcheck // the caller does the check
+// OblastQuery implements the [Handle] interface.
+func (h SqlHandle[T]) OblastQuery(ctx context.Context, query string, args []any) (handle.Rows, error) {
+ if !h.ok {
+ panic("SqlHandle was not constructed through oblast.Wrap()!")
+ }
+ return h.Base.QueryContext(ctx, query, args...) //nolint:rowserrcheck // the caller does the check
}
type wrappedStatement struct {
- db StdHandle
+ db SqlExecutor
query string
stmt *sql.Stmt // nil if repeated = false
}