aboutsummaryrefslogtreecommitdiff
path: root/benchmark/internal/oblast_pgx/handle.go
blob: 7ccc9ea504a42dd336f88f3c477bed7b9f87fab5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net>
// SPDX-License-Identifier: Apache-2.0

package oblast_pgx

import (
	"context"
	"fmt"
	"strconv"
	"sync/atomic"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
	"go.xyrillian.de/oblast/handle"
)

// Handle is the subset of the pgx API that this package relies on. Both
// *pgx.Conn and pgx.Tx satisfy it (checked at compile time below), and those
// two concrete types are the only ones accepted by [Wrap].
type Handle interface {
	Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
	QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
	Exec(ctx context.Context, sql string, args ...any) (pgconn.CommandTag, error)
}

// Compile-time assertions that the supported pgx types implement Handle.
var (
	_ Handle = (*pgx.Conn)(nil)
	_ Handle = pgx.Tx(nil)
)

// Wrap adapts a pgx connection or transaction into a [handle.Handle].
//
// Only *pgx.Conn and pgx.Tx are supported. Wrap panics for any other Handle
// implementation, including a nil interface value, because the methods of the
// returned handle need to reach the underlying *pgx.Conn and treat any other
// type as unreachable.
//
// TODO: offer wrapping for pgxpool.Pool and pgxpool.Conn?
func Wrap(h Handle) handle.Handle {
	switch h.(type) {
	case *pgx.Conn, pgx.Tx:
		return wrappedHandle{h}
	default:
		// Report the offending type (%T) rather than dumping the whole value
		// (%#v), which is noisy and does not say what went wrong.
		panic(fmt.Sprintf("oblast_pgx.Wrap: unsupported handle type %T (expected *pgx.Conn or pgx.Tx)", h))
	}
}

var (
	// preparedStatementId is a process-wide counter from which OblastPrepare
	// derives the names of server-side prepared statements. Since it is an
	// atomic counter, every call to Add yields a distinct value, even under
	// concurrent use.
	preparedStatementId atomic.Uint64
)

// wrappedHandle is the [handle.Handle] implementation returned by Wrap. Its
// only field, inner, holds the *pgx.Conn or pgx.Tx that was passed to Wrap.
type wrappedHandle struct{ inner Handle }

// OblastPrepare implements the [handle.Handle] interface.
//
// If repeated is false, no server-side preparation takes place: the query
// text and the wrapped handle are stored in a wrappedUnpreparedStatement.
//
// If repeated is true, the query is prepared on the underlying *pgx.Conn (for
// a pgx.Tx, on the transaction's connection) under a name that is unique
// within this process.
//
// If preparation fails, a nil Statement is returned together with the error,
// so that callers cannot end up using or releasing a statement that was never
// prepared.
func (h wrappedHandle) OblastPrepare(ctx context.Context, query string, repeated bool) (handle.Statement, error) {
	if !repeated {
		return wrappedUnpreparedStatement{query, h.inner}, nil
	}

	name := "oblast_pgx_" + strconv.FormatUint(preparedStatementId.Add(1), 10)

	// Resolve the connection that will own the prepared statement.
	var conn *pgx.Conn
	switch inner := h.inner.(type) {
	case *pgx.Conn:
		conn = inner
	case pgx.Tx:
		conn = inner.Conn()
	default:
		panic("unreachable") // because of the check in func Wrap()
	}

	stmt, err := conn.Prepare(ctx, name, query)
	if err != nil {
		// pgx returns a nil description on failure; do not wrap it into a
		// statement that would nil-deref once used or deallocated.
		return nil, err
	}
	// NOTE(review): ctx is retained inside the statement, presumably for later
	// use (e.g. deallocation); it may outlive this call — confirm intent.
	return wrappedPreparedStatement{ctx, stmt, h.inner}, nil
}

// deallocate releases the prepared statement described by stmt, identified by
// its name, on the connection underlying h. For a pgx.Tx, that is the
// connection returned by its Conn method.
func deallocate(ctx context.Context, h Handle, stmt *pgconn.StatementDescription) error {
	var target *pgx.Conn
	switch v := h.(type) {
	case *pgx.Conn:
		target = v
	case pgx.Tx:
		target = v.Conn()
	default:
		panic("unreachable") // because of the check in func Wrap()
	}
	return target.Deallocate(ctx, stmt.Name)
}

// OblastQuery implements the [handle.Handle] interface.
//
// The query is run through the wrapped connection or transaction, with args
// forwarded as pgx's variadic arguments. Like pgx's Query, the rows wrapper is
// returned even when an error is reported alongside it.
func (h wrappedHandle) OblastQuery(ctx context.Context, query string, args []any) (handle.Rows, error) {
	pgxRows, queryErr := h.inner.Query(ctx, query, args...)
	result := wrappedRows{pgxRows}
	return result, queryErr
}