-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathquery.go
124 lines (104 loc) · 2.3 KB
/
query.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
package queries
import (
"context"
"database/sql"
"fmt"
"iter"
"reflect"
)
// queryer is the minimal query capability that Query and QueryRow need.
// *sql.DB, *sql.Tx and *sql.Conn each have a QueryContext method with this
// exact signature, so any of them can be passed where a queryer is expected.
type queryer interface {
	QueryContext(ctx context.Context, sqlText string, params ...any) (*sql.Rows, error)
}
// Query returns an iterator over the rows produced by running query with
// args on q. Each row is decoded into a T by scan, which matches result
// columns to struct fields by their `sql` tag.
//
// The query is not executed until the iterator is ranged over. Every step
// yields either a decoded row with a nil error, or the zero T with a non-nil
// error; once an error has been yielded, iteration stops. The *sql.Rows is
// closed when iteration ends, including when the caller breaks out early.
func Query[T any](ctx context.Context, q queryer, query string, args ...any) iter.Seq2[T, error] {
	return func(yield func(T, error) bool) {
		rs, err := q.QueryContext(ctx, query, args...)
		if err != nil {
			yield(zero[T](), err)
			return
		}
		defer rs.Close()

		for rs.Next() {
			row, scanErr := scan[T](rs)
			if scanErr != nil {
				yield(zero[T](), scanErr)
				return
			}
			if !yield(row, nil) {
				// The consumer stopped ranging; the deferred Close releases the rows.
				return
			}
		}

		// Next returning false can mean either exhaustion or a failure; Err
		// distinguishes the two.
		if err = rs.Err(); err != nil {
			yield(zero[T](), err)
		}
	}
}
// QueryRow runs query with args on q and decodes the first result row into a
// T via scan (columns are matched to struct fields by their `sql` tag). Any
// additional rows are ignored.
//
// It returns sql.ErrNoRows when the query yields no rows. Like
// (*sql.Row).Scan, it explicitly closes the rows after decoding and returns
// any Close error, so a failure reported only while finishing the query is
// not silently dropped.
func QueryRow[T any](ctx context.Context, q queryer, query string, args ...any) (T, error) {
	rows, err := q.QueryContext(ctx, query, args...)
	if err != nil {
		return zero[T](), err
	}
	// Safety net for the early-return paths. Rows.Close is idempotent, so on
	// the success path the explicit Close below makes this a no-op.
	defer rows.Close()

	if !rows.Next() {
		if err := rows.Err(); err != nil {
			return zero[T](), err
		}
		return zero[T](), sql.ErrNoRows
	}
	t, err := scan[T](rows)
	if err != nil {
		return zero[T](), err
	}
	if err := rows.Err(); err != nil {
		return zero[T](), err
	}
	// Make sure the query can be processed to completion with no errors, as
	// (*sql.Row).Scan does: the driver's Close error is returned only here,
	// not via rows.Err.
	if err := rows.Close(); err != nil {
		return zero[T](), err
	}
	return t, nil
}
// zero returns the zero value of T. Generic code has no literal for "the zero
// T", so this gives return statements a compact way to produce one.
func zero[T any]() T {
	var z T
	return z
}
// rows is the part of a result set that scan reads from: the result column
// names and a way to copy the current row into destination pointers.
// *sql.Rows, which Query and QueryRow pass to scan, has both methods.
type rows interface {
	Scan(dest ...any) error
	Columns() ([]string, error)
}
// scan decodes the current row of rows into a fresh T. Each result column is
// paired with the struct field whose `sql` tag equals the column name (see
// parseStruct), and the field pointers are passed to rows.Scan in column
// order.
//
// It panics if T is not a struct type, or if a column has no matching field.
// A failure from Columns is wrapped with context; a failure from Scan is
// returned unchanged. On any error the zero T is returned.
func scan[T any](rows rows) (T, error) {
	var result T
	elem := reflect.ValueOf(&result).Elem()
	if kind := elem.Kind(); kind != reflect.Struct {
		panic("queries: T must be a struct")
	}

	names, err := rows.Columns()
	if err != nil {
		return zero[T](), fmt.Errorf("getting column names: %w", err)
	}

	// Pointers into result, keyed by `sql` tag.
	byTag := parseStruct(elem)
	dest := make([]any, 0, len(names))
	for _, name := range names {
		ptr, found := byTag[name]
		if !found {
			panic(fmt.Sprintf("queries: no field for column %q", name))
		}
		dest = append(dest, ptr)
	}

	if err = rows.Scan(dest...); err != nil {
		return zero[T](), err
	}
	return result, nil
}
// TODO: add sync.Map cache.
func parseStruct(v reflect.Value) map[string]any {
fields := make(map[string]any, v.NumField())
for i := range v.NumField() {
field := v.Field(i)
if !field.CanSet() {
continue
}
tag, ok := v.Type().Field(i).Tag.Lookup("sql")
if !ok {
continue
}
if tag == "" {
panic(fmt.Sprintf("queries: field %s has an empty `sql` tag", v.Type().Field(i).Name))
}
fields[tag] = field.Addr().Interface()
}
return fields
}