1
1
package queries
2
2
3
3
import (
4
- "errors "
4
+ "database/sql "
5
5
"fmt"
6
6
"reflect"
7
7
)
8
8
9
- // TODO: consider merging ScanOne() + ScanAll() -> Scan().
10
-
11
9
type Rows interface {
12
- Scan (... any ) error
13
10
Columns () ([]string , error )
14
11
Next () bool
12
+ Scan (... any ) error
15
13
Err () error
16
14
}
17
15
18
- func ScanOne (dst any , rows Rows ) error {
19
- v := reflect .ValueOf (dst )
20
- if ! v .IsValid () || v .Kind () != reflect .Ptr || v .Elem ().Kind () != reflect .Struct || v .IsNil () {
21
- panic ("queries: dst must be a non-nil struct pointer" )
22
- }
23
-
24
- fields := parseStruct (v .Elem ())
25
-
26
- columns , err := rows .Columns ()
27
- if err != nil {
28
- return fmt .Errorf ("getting column names: %w" , err )
29
- }
30
-
31
- target := make ([]any , len (columns ))
32
- for i , column := range columns {
33
- field , ok := fields [column ]
34
- if ! ok {
35
- panic (fmt .Sprintf ("queries: no field for the %#q column" , column ))
36
- }
37
- target [i ] = field
38
- }
39
-
40
- if ! rows .Next () {
41
- return errors .New ("queries: no rows to scan" )
42
- }
43
- if err := rows .Scan (target ... ); err != nil {
44
- return fmt .Errorf ("scanning rows: %w" , err )
45
- }
16
+ func Scan [T any ](dst * []T , rows Rows ) error {
17
+ return scan [T ](reflect .ValueOf (dst ).Elem (), rows )
18
+ }
46
19
47
- return rows .Err ()
20
+ func ScanRow [T any ](dst * T , rows Rows ) error {
21
+ return scan [T ](reflect .ValueOf (dst ).Elem (), rows )
48
22
}
49
23
50
- func ScanAll ( dst any , rows Rows ) error {
51
- v := reflect .ValueOf ( dst )
52
- if ! v . IsValid () || v . Kind () != reflect . Ptr || v . Elem (). Kind () != reflect . Slice || v . Elem (). Type (). Elem () .Kind () != reflect .Struct {
53
- panic ("queries: dst must be a pointer to a slice of structs " )
24
+ func scan [ T any ]( v reflect. Value , rows Rows ) error {
25
+ typ := reflect .TypeFor [ T ]( )
26
+ if typ .Kind () != reflect .Struct {
27
+ panic ("queries: T must be a struct " )
54
28
}
55
29
56
- slice := v .Elem ()
57
- typ := slice .Type ().Elem ()
58
- elem := reflect .New (typ ).Elem ()
59
- fields := parseStruct (elem )
30
+ strct := reflect .New (typ ).Elem ()
31
+ fields := parseStruct (strct )
60
32
61
33
columns , err := rows .Columns ()
62
34
if err != nil {
63
35
return fmt .Errorf ("getting column names: %w" , err )
64
36
}
65
37
66
- target := make ([]any , len (columns ))
38
+ into := make ([]any , len (columns ))
67
39
for i , column := range columns {
68
40
field , ok := fields [column ]
69
41
if ! ok {
70
- panic (fmt .Sprintf ("queries: no field for the %#q column " , column ))
42
+ panic (fmt .Sprintf ("queries: no field for column %q " , column ))
71
43
}
72
- target [i ] = field
44
+ into [i ] = field
73
45
}
74
46
47
+ slice := reflect .New (reflect .SliceOf (typ )).Elem ()
75
48
for rows .Next () {
76
- if err := rows .Scan (target ... ); err != nil {
77
- return fmt .Errorf ("scanning rows: %w" , err )
49
+ if err := rows .Scan (into ... ); err != nil {
50
+ return fmt .Errorf ("scanning row: %w" , err )
51
+ }
52
+ slice = reflect .Append (slice , strct )
53
+ }
54
+ if err := rows .Err (); err != nil {
55
+ return err
56
+ }
57
+
58
+ switch v .Kind () {
59
+ case reflect .Slice :
60
+ v .Set (slice )
61
+ case reflect .Struct :
62
+ if slice .Len () == 0 {
63
+ return sql .ErrNoRows
78
64
}
79
- slice .Set (reflect .Append (slice , elem ))
65
+ v .Set (slice .Index (0 ))
66
+ default :
67
+ panic ("unreachable" )
80
68
}
81
69
82
- return rows . Err ()
70
+ return nil
83
71
}
84
72
85
- // TODO: support nested structs.
86
73
func parseStruct (v reflect.Value ) map [string ]any {
87
74
fields := make (map [string ]any , v .NumField ())
88
75
@@ -92,16 +79,15 @@ func parseStruct(v reflect.Value) map[string]any {
92
79
continue
93
80
}
94
81
95
- sf := v .Type ().Field (i )
96
- name , ok := sf .Tag .Lookup ("sql" )
82
+ tag , ok := v .Type ().Field (i ).Tag .Lookup ("sql" )
97
83
if ! ok {
98
84
continue
99
85
}
100
- if name == "" {
101
- panic (fmt .Sprintf ("queries: %s field has an empty `sql` tag" , sf .Name ))
86
+ if tag == "" {
87
+ panic (fmt .Sprintf ("queries: field %s has an empty `sql` tag" , v . Type (). Field ( i ) .Name ))
102
88
}
103
89
104
- fields [name ] = field .Addr ().Interface ()
90
+ fields [tag ] = field .Addr ().Interface ()
105
91
}
106
92
107
93
return fields
0 commit comments