aboutsummaryrefslogtreecommitdiff
path: root/dialect.go
blob: 6505199fb57a1d8f5ad52f5aeb09874c0a3a3c32 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// SPDX-FileCopyrightText: 2026 Stefan Majewsky <majewsky@gmx.net>
// SPDX-License-Identifier: Apache-2.0

package oblast

import (
	"fmt"
	"strconv"
	"strings"
)

// Dialect accounts for differences between SQL dialects
// that are relevant to query generation within Oblast.
//
// # Compatibility notice
//
// This interface may be extended, even within minor versions, when doing so is
// required to add support for new DB dialects that differ from previously
// supported dialects in unexpected ways.
type Dialect interface {
	// Placeholder returns the placeholder for the i-th query argument.
	// Most dialects use "?", but e.g. PostgreSQL uses "$1", "$2" and so on.
	// Arguments are numbered from 0, like slice indexes.
	Placeholder(i int) string

	// QuoteIdentifier wraps the name of a column or table in quotes,
	// to prevent the name from being interpreted as a keyword.
	QuoteIdentifier(name string) string

	// UpsertClause generates an "ON CONFLICT" or similar clause
	// that can be appended to an INSERT query to make it fall back to
	// behaving like UPDATE if a record with the same primary key already exists.
	// The returned clause starts with a space, so that it can be appended
	// to the INSERT query directly.
	//
	// This is only used for record types that have a primary key,
	// so pkColumns is never empty. otherColumns may be empty.
	UpsertClause(pkColumns, otherColumns []string) string
}

// MariaDBDialect returns the Dialect for MariaDB 10.5 and newer.
//
// Neither MySQL nor MariaDB releases older than 10.5 (released 2020-06-24)
// are supported, since they do not understand the "INSERT ... RETURNING" syntax.
func MariaDBDialect() Dialect {
	var d mariadbDialect
	return d
}

type mariadbDialect struct{}

// Placeholder implements the Dialect interface. MariaDB uses positional "?"
// placeholders, so the argument index is irrelevant.
func (mariadbDialect) Placeholder(_ int) string { return "?" }

// QuoteIdentifier implements the Dialect interface by wrapping name in backticks.
// Backticks inside name are doubled, which is how MariaDB escapes them within a
// quoted identifier, so that the result is always parsed as one identifier.
func (mariadbDialect) QuoteIdentifier(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}

// UpsertClause implements the Dialect interface using MariaDB's
// "ON DUPLICATE KEY UPDATE" syntax. Each non-PK column is overwritten with the
// value from the attempted INSERT, which is referenced as VALUES(col).
// pkColumns must not be empty (it is only consulted when otherColumns is empty).
//
// NOTE: Unlike PostgreSQL's "ON CONFLICT (pk)", MariaDB triggers this clause on a
// conflict with ANY unique key of the table, not just the primary key.
func (d mariadbDialect) UpsertClause(pkColumns, otherColumns []string) string {
	if len(otherColumns) == 0 {
		// The syntax requires at least one assignment. Without non-PK columns,
		// assign a PK column to itself: "col = col" keeps the existing value, so
		// this is a true no-op even if the duplicate was detected on a different
		// unique key. ("col = VALUES(col)" would rewrite the existing row's
		// primary key in that case.)
		return ` ON DUPLICATE KEY UPDATE ` + fmt.Sprintf(`%[1]s = %[1]s`, d.QuoteIdentifier(pkColumns[0]))
	}

	clauses := make([]string, len(otherColumns))
	for idx, name := range otherColumns {
		clauses[idx] = fmt.Sprintf(`%[1]s = VALUES(%[1]s)`, d.QuoteIdentifier(name))
	}
	return ` ON DUPLICATE KEY UPDATE ` + strings.Join(clauses, ", ")
}

// PostgresDialect returns the Dialect for PostgreSQL databases.
//
// The generated upsert clauses use "INSERT ... ON CONFLICT", which requires
// PostgreSQL 9.5 or newer.
func PostgresDialect() Dialect {
	var d postgresDialect
	return d
}

type postgresDialect struct{}

// Placeholder implements the Dialect interface. PostgreSQL uses numbered
// placeholders starting at $1, so the 0-based index i is shifted by one.
func (postgresDialect) Placeholder(i int) string { return "$" + strconv.Itoa(i+1) }

// QuoteIdentifier implements the Dialect interface by wrapping name in double
// quotes. Double quotes inside name are doubled, which is how PostgreSQL (like
// the SQL standard) escapes them within a quoted identifier.
func (postgresDialect) QuoteIdentifier(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// UpsertClause implements the Dialect interface using "ON CONFLICT (pk) ...".
//
// Without non-PK columns there is nothing to update, so a conflicting INSERT
// becomes a no-op via DO NOTHING. Otherwise, every non-PK column is overwritten
// with the value from the attempted INSERT, which PostgreSQL exposes through
// the special EXCLUDED table.
//
// NOTE(review): with DO NOTHING, PostgreSQL returns no row from a RETURNING
// clause for a skipped record; confirm that callers handle this.
func (d postgresDialect) UpsertClause(pkColumns, otherColumns []string) string {
	quotedPkColumns := make([]string, len(pkColumns))
	for idx, name := range pkColumns {
		quotedPkColumns[idx] = d.QuoteIdentifier(name)
	}
	conflictTarget := strings.Join(quotedPkColumns, ", ")

	if len(otherColumns) == 0 {
		return fmt.Sprintf(` ON CONFLICT (%s) DO NOTHING`, conflictTarget)
	}

	clauses := make([]string, len(otherColumns))
	for idx, name := range otherColumns {
		clauses[idx] = fmt.Sprintf(`%[1]s = EXCLUDED.%[1]s`, d.QuoteIdentifier(name))
	}
	return fmt.Sprintf(` ON CONFLICT (%s) DO UPDATE SET %s`, conflictTarget, strings.Join(clauses, ", "))
}

// SqliteDialect is the dialect of SQLite 3.35.0+ databases.
//
// This dialect does NOT support older SQLite versions (3.35.0 was released 2021-03-12),
// because those do not understand the "INSERT ... RETURNING" syntax.
// (The "ON CONFLICT ... DO UPDATE" upsert syntax used by this dialect is older;
// it was added in SQLite 3.24.0.)
func SqliteDialect() Dialect {
	return sqliteDialect{}
}

type sqliteDialect struct{}

// Placeholder implements the Dialect interface. SQLite accepts positional "?"
// placeholders, so the argument index is irrelevant.
func (sqliteDialect) Placeholder(_ int) string { return "?" }

// QuoteIdentifier implements the Dialect interface by wrapping name in double
// quotes. Double quotes inside name are doubled, which is how SQLite escapes
// them within a quoted identifier.
func (sqliteDialect) QuoteIdentifier(name string) string {
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}

// UpsertClause implements the Dialect interface.
//
// SQLite understands the same "ON CONFLICT (...) DO NOTHING" and
// "ON CONFLICT (...) DO UPDATE SET col = EXCLUDED.col" syntax as PostgreSQL,
// and both dialects quote identifiers with double quotes, so the PostgreSQL
// implementation is reused as-is.
func (sqliteDialect) UpsertClause(pkColumns, otherColumns []string) string {
	return postgresDialect{}.UpsertClause(pkColumns, otherColumns)
}