Skip to content

Commit 35ee52a

Browse files
authored
feat(scanner)!: Scan -> Query (#2)
1 parent 25ae413 commit 35ee52a

File tree

6 files changed

+201
-212
lines changed

6 files changed

+201
-212
lines changed

example_test.go

+8-11
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,25 @@ func run(ctx context.Context) error {
3535
columns = append(columns, "created_at")
3636
}
3737

38+
// select first_name, last_name, created_at from users where created_at >= $1
3839
var qb queries.Builder
3940
qb.Appendf("select %s from users", strings.Join(columns, ", "))
4041
if true {
4142
qb.Appendf(" where created_at >= %$", time.Date(2024, time.January, 1, 0, 0, 0, 0, time.Local))
4243
}
4344

44-
// select first_name, last_name, created_at from users where created_at >= $1
45-
rows, err := db.QueryContext(ctx, qb.Query(), qb.Args()...)
46-
if err != nil {
47-
return err
48-
}
49-
defer rows.Close()
50-
51-
var users []struct {
45+
type user struct {
5246
FirstName string `sql:"first_name"`
5347
LastName string `sql:"last_name"`
5448
CreatedAt time.Time `sql:"created_at"`
5549
}
56-
if err := queries.Scan(&users, rows); err != nil {
57-
return err
50+
51+
for user, err := range queries.Query[user](ctx, db, qb.Query(), qb.Args()...) {
52+
if err != nil {
53+
return err
54+
}
55+
fmt.Println(user)
5856
}
5957

60-
fmt.Println(users)
6158
return nil
6259
}

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
module go-simpler.org/queries
22

3-
go 1.22
3+
go 1.23

query.go

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package queries
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"fmt"
7+
"iter"
8+
"reflect"
9+
)
10+
11+
type queryer interface {
12+
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
13+
}
14+
15+
func Query[T any](ctx context.Context, q queryer, query string, args ...any) iter.Seq2[T, error] {
16+
return func(yield func(T, error) bool) {
17+
rows, err := q.QueryContext(ctx, query, args...)
18+
if err != nil {
19+
yield(zero[T](), err)
20+
return
21+
}
22+
defer rows.Close()
23+
24+
for rows.Next() {
25+
t, err := scan[T](rows)
26+
if err != nil {
27+
yield(zero[T](), err)
28+
return
29+
}
30+
if !yield(t, nil) {
31+
return
32+
}
33+
}
34+
if err := rows.Err(); err != nil {
35+
yield(zero[T](), err)
36+
return
37+
}
38+
}
39+
}
40+
41+
func QueryRow[T any](ctx context.Context, q queryer, query string, args ...any) (T, error) {
42+
rows, err := q.QueryContext(ctx, query, args...)
43+
if err != nil {
44+
return zero[T](), err
45+
}
46+
defer rows.Close()
47+
48+
if !rows.Next() {
49+
if err := rows.Err(); err != nil {
50+
return zero[T](), err
51+
}
52+
return zero[T](), sql.ErrNoRows
53+
}
54+
55+
t, err := scan[T](rows)
56+
if err != nil {
57+
return zero[T](), err
58+
}
59+
if err := rows.Err(); err != nil {
60+
return zero[T](), err
61+
}
62+
63+
return t, nil
64+
}
65+
66+
func zero[T any]() (t T) { return t }
67+
68+
type rows interface {
69+
Columns() ([]string, error)
70+
Scan(...any) error
71+
}
72+
73+
func scan[T any](rows rows) (T, error) {
74+
var t T
75+
v := reflect.ValueOf(&t).Elem()
76+
if v.Kind() != reflect.Struct {
77+
panic("queries: T must be a struct")
78+
}
79+
80+
columns, err := rows.Columns()
81+
if err != nil {
82+
return zero[T](), fmt.Errorf("getting column names: %w", err)
83+
}
84+
85+
fields := parseStruct(v)
86+
args := make([]any, len(columns))
87+
88+
for i, column := range columns {
89+
field, ok := fields[column]
90+
if !ok {
91+
panic(fmt.Sprintf("queries: no field for column %q", column))
92+
}
93+
args[i] = field
94+
}
95+
if err := rows.Scan(args...); err != nil {
96+
return zero[T](), err
97+
}
98+
99+
return t, nil
100+
}
101+
102+
// TODO: add sync.Map cache.
103+
func parseStruct(v reflect.Value) map[string]any {
104+
fields := make(map[string]any, v.NumField())
105+
106+
for i := range v.NumField() {
107+
field := v.Field(i)
108+
if !field.CanSet() {
109+
continue
110+
}
111+
112+
tag, ok := v.Type().Field(i).Tag.Lookup("sql")
113+
if !ok {
114+
continue
115+
}
116+
if tag == "" {
117+
panic(fmt.Sprintf("queries: field %s has an empty `sql` tag", v.Type().Field(i).Name))
118+
}
119+
120+
fields[tag] = field.Addr().Interface()
121+
}
122+
123+
return fields
124+
}

query_test.go

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package queries
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
7+
"go-simpler.org/queries/internal/assert"
8+
. "go-simpler.org/queries/internal/assert/EF"
9+
)
10+
11+
func Test_scan(t *testing.T) {
12+
t.Run("non-struct T", func(t *testing.T) {
13+
fn := func() { _, _ = scan[int](new(mockRows)) }
14+
assert.Panics[E](t, fn, "queries: T must be a struct")
15+
})
16+
17+
t.Run("empty tag", func(t *testing.T) {
18+
type row struct {
19+
Foo int `sql:""`
20+
}
21+
fn := func() { _, _ = scan[row](new(mockRows)) }
22+
assert.Panics[E](t, fn, "queries: field Foo has an empty `sql` tag")
23+
})
24+
25+
t.Run("missing field", func(t *testing.T) {
26+
rows := mockRows{
27+
columns: []string{"foo", "bar"},
28+
}
29+
30+
type row struct {
31+
Foo int `sql:"foo"`
32+
Bar string
33+
}
34+
fn := func() { _, _ = scan[row](&rows) }
35+
assert.Panics[E](t, fn, `queries: no field for column "bar"`)
36+
})
37+
38+
t.Run("ok", func(t *testing.T) {
39+
rows := mockRows{
40+
columns: []string{"foo", "bar"},
41+
values: []any{1, "A"},
42+
}
43+
44+
type row struct {
45+
Foo int `sql:"foo"`
46+
Bar string `sql:"bar"`
47+
}
48+
r, err := scan[row](&rows)
49+
assert.NoErr[F](t, err)
50+
assert.Equal[E](t, r.Foo, 1)
51+
assert.Equal[E](t, r.Bar, "A")
52+
})
53+
}
54+
55+
type mockRows struct {
56+
columns []string
57+
values []any
58+
}
59+
60+
func (r *mockRows) Columns() ([]string, error) { return r.columns, nil }
61+
62+
func (r *mockRows) Scan(dst ...any) error {
63+
for i := range dst {
64+
v := reflect.ValueOf(r.values[i])
65+
reflect.ValueOf(dst[i]).Elem().Set(v)
66+
}
67+
return nil
68+
}

scanner.go

-94
This file was deleted.

0 commit comments

Comments
 (0)