forked from letsencrypt/boulder
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gorm.go
224 lines (197 loc) · 7.62 KB
/
gorm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
package db
import (
"context"
"database/sql"
"fmt"
"reflect"
"regexp"
"strings"
)
// Characters allowed in an unquoted identifier by MariaDB.
// https://mariadb.com/kb/en/identifier-names/#unquoted
var mariaDBUnquotedIdentifierRE = regexp.MustCompile("^[0-9a-zA-Z$_]+$")

// validMariaDBUnquotedIdentifier returns an error unless s is safe to
// interpolate into a query as an unquoted MariaDB identifier. It rejects
// strings that:
//   - are empty or contain characters outside [0-9a-zA-Z$_];
//   - consist solely of numerals, which MariaDB does not allow for unquoted
//     identifiers; or
//   - begin with one or more numerals followed by an 'e' or 'E', which MariaDB
//     may parse as a floating-point literal in scientific notation.
func validMariaDBUnquotedIdentifier(s string) error {
	if !mariaDBUnquotedIdentifierRE.MatchString(s) {
		return fmt.Errorf("invalid MariaDB identifier %q", s)
	}

	// Measure the leading run of ASCII digits, if any. The regexp above
	// guarantees s is non-empty and ASCII-only, so byte indexing is safe.
	digits := 0
	for digits < len(s) && s[digits] >= '0' && s[digits] <= '9' {
		digits++
	}
	if digits == len(s) {
		return fmt.Errorf("MariaDB identifier contains only numerals: %q", s)
	}
	// The exponent marker in numeric literals is case-insensitive (1e5 and
	// 1E5 are the same float), so both spellings must be rejected.
	if digits > 0 && (s[digits] == 'e' || s[digits] == 'E') {
		return fmt.Errorf("MariaDB identifier looks like floating point: %q", s)
	}
	return nil
}
// NewMappedSelector returns an object which can be used to automagically query
// the provided type-mapped database for rows of the parameterized type.
//
// T must be a struct type, and it is mapped to table columns very strictly:
//   - The struct must not have any embedded structs, only named fields.
//   - Every field must be exported, because rows are scanned into the fields
//     via reflection, which cannot hand out pointers to unexported fields.
//   - The struct field names must be case-insensitively identical to the
//     column names (no struct tags necessary).
//   - The struct field names must be case-insensitively unique.
//   - Every field of the struct must correspond to a database column.
//   - Note that the reverse is not true: it's perfectly okay for there to be
//     database columns which do not correspond to fields in the struct; those
//     columns will be ignored.
//
// An error is returned if T violates any of these rules.
func NewMappedSelector[T any](executor MappedExecutor) (MappedSelector[T], error) {
	// Derive the type from a nil *T rather than a zero T so that an interface
	// type argument yields its own type instead of a nil reflect.Type.
	t := reflect.TypeOf((*T)(nil)).Elem()
	if t.Kind() != reflect.Struct {
		// Without this check, t.NumField() below would panic.
		return nil, fmt.Errorf("database model type %s is not a struct", t)
	}

	// TODO: In the future, when we replace borp's TableMap with our own, this
	// check should be performed at the time the mapping is declared.
	columns := make([]string, 0, t.NumField())
	seen := make(map[string]struct{}, t.NumField())
	for i := range t.NumField() {
		field := t.Field(i)
		if field.Anonymous {
			return nil, fmt.Errorf("struct contains anonymous embedded struct %q", field.Name)
		}
		if !field.IsExported() {
			// rows.Get calls Addr().Interface() on every field, which panics
			// for unexported fields; reject them here instead.
			return nil, fmt.Errorf("struct contains unexported field %q", field.Name)
		}
		column := strings.ToLower(field.Name)
		if err := validMariaDBUnquotedIdentifier(column); err != nil {
			return nil, fmt.Errorf("struct field maps to unsafe db column name %q: %w", column, err)
		}
		if _, found := seen[column]; found {
			return nil, fmt.Errorf("struct fields map to duplicate column name %q", column)
		}
		seen[column] = struct{}{}
		// Columns are recorded in struct-field order; Get relies on this to
		// scan the i-th selected column into the i-th field.
		columns = append(columns, column)
	}

	return &mappedSelector[T]{wrapped: executor, columns: columns}, nil
}
// mappedSelector is the concrete implementation handed out by
// NewMappedSelector. It pairs the executor that queries are sent to with the
// column names derived from T's fields.
type mappedSelector[T any] struct {
	// wrapped runs the generated queries and resolves T's table mapping.
	wrapped MappedExecutor
	// columns holds one lowercased column name per field of T, in field order.
	columns []string
}
// QueryContext performs a SELECT on the table that the wrapped executor has
// mapped to T. It combines the best features of borp, the go stdlib, and
// generics: the table name comes from T's registered mapping, the selected
// columns come from T's fields, and the returned Rows yields fully-populated
// objects of type T directly.
//
// The given clauses MUST be only the bits of a sql query from "WHERE ..."
// onwards; if they contain any of the "SELECT ... FROM ..." portion of the
// query it will result in an error. The args take the same kinds of values as
// borp's SELECT: either one argument per positional placeholder, or a map of
// placeholder names to their arguments
// (see https://pkg.go.dev/github.com/letsencrypt/borp#readme-ad-hoc-sql).
//
// The caller is responsible for calling `Rows.Close()` when they are done with
// the query. The caller is also responsible for ensuring that the clauses
// argument does not contain any user-influenced input.
func (ts mappedSelector[T]) QueryContext(ctx context.Context, clauses string, args ...interface{}) (Rows[T], error) {
	var zero T
	modelType := reflect.TypeOf(zero)

	mapping, err := ts.wrapped.TableFor(modelType, false)
	if err != nil {
		return nil, fmt.Errorf("database model type not mapped to table name: %w", err)
	}

	// Everything after the table lookup is shared with the explicit-table path.
	return ts.QueryFrom(ctx, mapping.TableName, clauses, args...)
}
// QueryFrom is the same as QueryContext, but it additionally takes a table
// name to select from, rather than automatically computing the table name from
// borp's DbMap.
//
// tablename must be a valid unquoted MariaDB identifier; otherwise an error is
// returned and no query is issued.
//
// The caller is responsible for calling `Rows.Close()` when they are done with
// the query. The caller is also responsible for ensuring that the clauses
// argument does not contain any user-influenced input.
func (ts mappedSelector[T]) QueryFrom(ctx context.Context, tablename string, clauses string, args ...interface{}) (Rows[T], error) {
	// The table name is interpolated directly into the query text, so it must
	// be validated as a safe identifier before use.
	err := validMariaDBUnquotedIdentifier(tablename)
	if err != nil {
		return nil, fmt.Errorf("unsafe db table name: %w", err)
	}

	// Construct the query from the column names, table name, and given clauses.
	// The column names are listed in the order of T's struct fields (the order
	// in which they were recorded when this selector was built), which is what
	// allows rows[T].Get to scan the i-th column into the i-th field.
	query := fmt.Sprintf(
		"SELECT %s FROM %s %s",
		strings.Join(ts.columns, ", "),
		tablename,
		clauses,
	)

	r, err := ts.wrapped.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("reading db: %w", err)
	}

	return &rows[T]{wrapped: r, numCols: len(ts.columns)}, nil
}
// rows is a wrapper around the stdlib's sql.rows, but with a more
// type-safe method to get actual row content.
type rows[T any] struct {
wrapped *sql.Rows
numCols int
}
// ForEach invokes do once per row, passing each decoded model object in the
// order the rows are read. Iteration stops at the first error, whether it
// comes from reading a row or from do itself. In every case the underlying row
// reader is closed before ForEach returns, so ForEach can only be called once
// per rows object. This is the intended way to consume the result of
// QueryContext or QueryFrom; the other methods on this type are lower-level
// and intended for advanced use only.
func (r rows[T]) ForEach(do func(*T) error) (err error) {
	// Always close the reader on the way out. The named return lets this
	// deferred func merge a close failure into whatever the body returned.
	defer func() {
		closeErr := r.Close()
		switch {
		case closeErr == nil:
			// Nothing to add: keep the body's result unchanged.
		case err == nil:
			err = closeErr
		default:
			err = fmt.Errorf("%w; also while closing the row reader: %w", err, closeErr)
		}
	}()

	for r.Next() {
		item, getErr := r.Get()
		if getErr != nil {
			return fmt.Errorf("reading row: %w", getErr)
		}
		if doErr := do(item); doErr != nil {
			return doErr
		}
	}

	if iterErr := r.Err(); iterErr != nil {
		return fmt.Errorf("iterating over row reader: %w", iterErr)
	}
	return nil
}
// Next advances to the next row by delegating to sql.Rows.Next(), reporting
// whether a row is available. Call it before every call to Get(), the first
// one included.
func (r rows[T]) Next() bool {
	return r.wrapped.Next()
}
// Get reads the current row via sql.Rows.Scan(). Instead of making the caller
// pass one pointer per column, it allocates a fresh *T, scans the row directly
// into that struct's fields, and returns it.
func (r rows[T]) Get() (*T, error) {
	out := new(T)
	elem := reflect.ValueOf(out).Elem()

	// Scan wants one destination pointer per selected column. The columns were
	// selected in struct-field order, so field i receives column i.
	dests := make([]any, 0, r.numCols)
	for i := 0; i < r.numCols; i++ {
		dests = append(dests, elem.Field(i).Addr().Interface())
	}

	if err := r.wrapped.Scan(dests...); err != nil {
		return nil, fmt.Errorf("reading db row: %w", err)
	}
	return out, nil
}
// Err reports any error hit during iteration, delegating to sql.Rows.Err().
// Check it right after Next() returns false, whatever the reason.
func (r rows[T]) Err() error {
	return r.wrapped.Err()
}
// Close releases the result set by delegating to sql.Rows.Close(). Callers
// must invoke it once they have finished reading rows, whether they stopped
// because of success or an error. ForEach calls it automatically.
func (r rows[T]) Close() error {
	return r.wrapped.Close()
}