Skip to content

Commit 9020f91

Browse files
committed
feat(scanner): new API and tests
1 parent 9450f41 commit 9020f91

File tree

3 files changed

+147
-55
lines changed

3 files changed

+147
-55
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
module go-simpler.org/queries
22

3-
go 1.18
3+
go 1.22

scanner.go

+40-54
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,75 @@
11
package queries
22

33
import (
4-
"errors"
4+
"database/sql"
55
"fmt"
66
"reflect"
77
)
88

9-
// TODO: consider merging ScanOne() + ScanAll() -> Scan().
10-
119
type Rows interface {
12-
Scan(...any) error
1310
Columns() ([]string, error)
1411
Next() bool
12+
Scan(...any) error
1513
Err() error
1614
}
1715

18-
func ScanOne(dst any, rows Rows) error {
19-
v := reflect.ValueOf(dst)
20-
if !v.IsValid() || v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct || v.IsNil() {
21-
panic("queries: dst must be a non-nil struct pointer")
22-
}
23-
24-
fields := parseStruct(v.Elem())
25-
26-
columns, err := rows.Columns()
27-
if err != nil {
28-
return fmt.Errorf("getting column names: %w", err)
29-
}
30-
31-
target := make([]any, len(columns))
32-
for i, column := range columns {
33-
field, ok := fields[column]
34-
if !ok {
35-
panic(fmt.Sprintf("queries: no field for the %#q column", column))
36-
}
37-
target[i] = field
38-
}
39-
40-
if !rows.Next() {
41-
return errors.New("queries: no rows to scan")
42-
}
43-
if err := rows.Scan(target...); err != nil {
44-
return fmt.Errorf("scanning rows: %w", err)
45-
}
16+
func Scan[T any](dst *[]T, rows Rows) error {
17+
return scan[T](reflect.ValueOf(dst).Elem(), rows)
18+
}
4619

47-
return rows.Err()
20+
func ScanRow[T any](dst *T, rows Rows) error {
21+
return scan[T](reflect.ValueOf(dst).Elem(), rows)
4822
}
4923

50-
func ScanAll(dst any, rows Rows) error {
51-
v := reflect.ValueOf(dst)
52-
if !v.IsValid() || v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Slice || v.Elem().Type().Elem().Kind() != reflect.Struct {
53-
panic("queries: dst must be a pointer to a slice of structs")
24+
func scan[T any](v reflect.Value, rows Rows) error {
25+
typ := reflect.TypeFor[T]()
26+
if typ.Kind() != reflect.Struct {
27+
panic("queries: T must be a struct")
5428
}
5529

56-
slice := v.Elem()
57-
typ := slice.Type().Elem()
58-
elem := reflect.New(typ).Elem()
59-
fields := parseStruct(elem)
30+
strct := reflect.New(typ).Elem()
31+
fields := parseStruct(strct)
6032

6133
columns, err := rows.Columns()
6234
if err != nil {
6335
return fmt.Errorf("getting column names: %w", err)
6436
}
6537

66-
target := make([]any, len(columns))
38+
into := make([]any, len(columns))
6739
for i, column := range columns {
6840
field, ok := fields[column]
6941
if !ok {
70-
panic(fmt.Sprintf("queries: no field for the %#q column", column))
42+
panic(fmt.Sprintf("queries: no field for column %q", column))
7143
}
72-
target[i] = field
44+
into[i] = field
7345
}
7446

47+
slice := reflect.New(reflect.SliceOf(typ)).Elem()
7548
for rows.Next() {
76-
if err := rows.Scan(target...); err != nil {
77-
return fmt.Errorf("scanning rows: %w", err)
49+
if err := rows.Scan(into...); err != nil {
50+
return fmt.Errorf("scanning row: %w", err)
51+
}
52+
slice = reflect.Append(slice, strct)
53+
}
54+
if err := rows.Err(); err != nil {
55+
return err
56+
}
57+
58+
switch v.Kind() {
59+
case reflect.Slice:
60+
v.Set(slice)
61+
case reflect.Struct:
62+
if slice.Len() == 0 {
63+
return sql.ErrNoRows
7864
}
79-
slice.Set(reflect.Append(slice, elem))
65+
v.Set(slice.Index(0))
66+
default:
67+
panic("unreachable")
8068
}
8169

82-
return rows.Err()
70+
return nil
8371
}
8472

85-
// TODO: support nested structs.
8673
func parseStruct(v reflect.Value) map[string]any {
8774
fields := make(map[string]any, v.NumField())
8875

@@ -92,16 +79,15 @@ func parseStruct(v reflect.Value) map[string]any {
9279
continue
9380
}
9481

95-
sf := v.Type().Field(i)
96-
name, ok := sf.Tag.Lookup("sql")
82+
tag, ok := v.Type().Field(i).Tag.Lookup("sql")
9783
if !ok {
9884
continue
9985
}
100-
if name == "" {
101-
panic(fmt.Sprintf("queries: %s field has an empty `sql` tag", sf.Name))
86+
if tag == "" {
87+
panic(fmt.Sprintf("queries: field %s has an empty `sql` tag", v.Type().Field(i).Name))
10288
}
10389

104-
fields[name] = field.Addr().Interface()
90+
fields[tag] = field.Addr().Interface()
10591
}
10692

10793
return fields

scanner_test.go

+106
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package queries_test
2+
3+
import (
4+
"database/sql"
5+
"reflect"
6+
"testing"
7+
8+
"go-simpler.org/queries"
9+
"go-simpler.org/queries/internal/assert"
10+
. "go-simpler.org/queries/internal/assert/EF"
11+
)
12+
13+
func Test_misuse(t *testing.T) {
14+
t.Run("non-struct T", func(t *testing.T) {
15+
const panicMsg = "queries: T must be a struct"
16+
17+
assert.Panics[E](t, func() { _ = queries.Scan(new([]int), nil) }, panicMsg)
18+
assert.Panics[E](t, func() { _ = queries.ScanRow(new(int), nil) }, panicMsg)
19+
})
20+
21+
t.Run("empty tag", func(t *testing.T) {
22+
const panicMsg = "queries: field Foo has an empty `sql` tag"
23+
24+
type dst struct {
25+
Foo int `sql:""`
26+
}
27+
assert.Panics[E](t, func() { _ = queries.Scan(new([]dst), nil) }, panicMsg)
28+
assert.Panics[E](t, func() { _ = queries.ScanRow(new(dst), nil) }, panicMsg)
29+
})
30+
31+
t.Run("missing field", func(t *testing.T) {
32+
const panicMsg = `queries: no field for column "foo"`
33+
34+
rows := mockRows{columns: []string{"foo"}}
35+
36+
type dst struct {
37+
Foo int
38+
}
39+
assert.Panics[E](t, func() { _ = queries.Scan(new([]dst), &rows) }, panicMsg)
40+
assert.Panics[E](t, func() { _ = queries.ScanRow(new(dst), &rows) }, panicMsg)
41+
})
42+
}
43+
44+
func TestScan(t *testing.T) {
45+
rows := mockRows{
46+
columns: []string{"foo", "bar"},
47+
values: [][]any{{1, "A"}, {2, "B"}},
48+
}
49+
50+
var dst []struct {
51+
Foo int `sql:"foo"`
52+
Bar string `sql:"bar"`
53+
}
54+
err := queries.Scan(&dst, &rows)
55+
assert.NoErr[F](t, err)
56+
assert.Equal[E](t, len(dst), 2)
57+
assert.Equal[E](t, dst[0].Foo, 1)
58+
assert.Equal[E](t, dst[1].Foo, 2)
59+
assert.Equal[E](t, dst[0].Bar, "A")
60+
assert.Equal[E](t, dst[1].Bar, "B")
61+
}
62+
63+
func TestScanRow(t *testing.T) {
64+
rows := mockRows{
65+
columns: []string{"foo", "bar"},
66+
values: [][]any{{1, "A"}},
67+
}
68+
69+
var dst struct {
70+
Foo int `sql:"foo"`
71+
Bar string `sql:"bar"`
72+
}
73+
err := queries.ScanRow(&dst, &rows)
74+
assert.NoErr[F](t, err)
75+
assert.Equal[E](t, dst.Foo, 1)
76+
assert.Equal[E](t, dst.Bar, "A")
77+
78+
t.Run("no rows", func(t *testing.T) {
79+
rows := mockRows{columns: []string{"foo"}}
80+
81+
var dst struct {
82+
Foo int `sql:"foo"`
83+
}
84+
err := queries.ScanRow(&dst, &rows)
85+
assert.IsErr[E](t, err, sql.ErrNoRows)
86+
})
87+
}
88+
89+
type mockRows struct {
90+
columns []string
91+
values [][]any
92+
idx int
93+
}
94+
95+
func (r *mockRows) Columns() ([]string, error) { return r.columns, nil }
96+
func (r *mockRows) Next() bool { return r.idx < len(r.values) }
97+
func (r *mockRows) Err() error { return nil }
98+
99+
func (r *mockRows) Scan(dst ...any) error {
100+
for i := 0; i < len(dst); i++ {
101+
v := reflect.ValueOf(r.values[r.idx][i])
102+
reflect.ValueOf(dst[i]).Elem().Set(v)
103+
}
104+
r.idx++
105+
return nil
106+
}

0 commit comments

Comments
 (0)