Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(scanner)!: Scan -> Query #2

Merged
merged 3 commits into from
Mar 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,25 @@ func run(ctx context.Context) error {
columns = append(columns, "created_at")
}

// select first_name, last_name, created_at from users where created_at >= $1
var qb queries.Builder
qb.Appendf("select %s from users", strings.Join(columns, ", "))
if true {
qb.Appendf(" where created_at >= %$", time.Date(2024, time.January, 1, 0, 0, 0, 0, time.Local))
}

// select first_name, last_name, created_at from users where created_at >= $1
rows, err := db.QueryContext(ctx, qb.Query(), qb.Args()...)
if err != nil {
return err
}
defer rows.Close()

var users []struct {
type user struct {
FirstName string `sql:"first_name"`
LastName string `sql:"last_name"`
CreatedAt time.Time `sql:"created_at"`
}
if err := queries.Scan(&users, rows); err != nil {
return err

for user, err := range queries.Query[user](ctx, db, qb.Query(), qb.Args()...) {
if err != nil {
return err
}
fmt.Println(user)
}

fmt.Println(users)
return nil
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
module go-simpler.org/queries

go 1.22
go 1.23
124 changes: 124 additions & 0 deletions query.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package queries

import (
"context"
"database/sql"
"fmt"
"iter"
"reflect"
)

type queryer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}

func Query[T any](ctx context.Context, q queryer, query string, args ...any) iter.Seq2[T, error] {
return func(yield func(T, error) bool) {
rows, err := q.QueryContext(ctx, query, args...)
if err != nil {
yield(zero[T](), err)
return
}
defer rows.Close()

for rows.Next() {
t, err := scan[T](rows)
if err != nil {
yield(zero[T](), err)
return
}
if !yield(t, nil) {
return
}

Check warning on line 32 in query.go

View check run for this annotation

Codecov / codecov/patch

query.go#L15-L32

Added lines #L15 - L32 were not covered by tests
}
if err := rows.Err(); err != nil {
yield(zero[T](), err)
return
}

Check warning on line 37 in query.go

View check run for this annotation

Codecov / codecov/patch

query.go#L34-L37

Added lines #L34 - L37 were not covered by tests
}
}

func QueryRow[T any](ctx context.Context, q queryer, query string, args ...any) (T, error) {
rows, err := q.QueryContext(ctx, query, args...)
if err != nil {
return zero[T](), err
}
defer rows.Close()

if !rows.Next() {
if err := rows.Err(); err != nil {
return zero[T](), err
}
return zero[T](), sql.ErrNoRows

Check warning on line 52 in query.go

View check run for this annotation

Codecov / codecov/patch

query.go#L41-L52

Added lines #L41 - L52 were not covered by tests
}

t, err := scan[T](rows)
if err != nil {
return zero[T](), err
}
if err := rows.Err(); err != nil {
return zero[T](), err
}

Check warning on line 61 in query.go

View check run for this annotation

Codecov / codecov/patch

query.go#L55-L61

Added lines #L55 - L61 were not covered by tests

return t, nil

Check warning on line 63 in query.go

View check run for this annotation

Codecov / codecov/patch

query.go#L63

Added line #L63 was not covered by tests
}

func zero[T any]() (t T) { return t }

Check warning on line 66 in query.go

View check run for this annotation

Codecov / codecov/patch

query.go#L66

Added line #L66 was not covered by tests

type rows interface {
Columns() ([]string, error)
Scan(...any) error
}

func scan[T any](rows rows) (T, error) {
var t T
v := reflect.ValueOf(&t).Elem()
if v.Kind() != reflect.Struct {
panic("queries: T must be a struct")
}

columns, err := rows.Columns()
if err != nil {
return zero[T](), fmt.Errorf("getting column names: %w", err)
}

Check warning on line 83 in query.go

View check run for this annotation

Codecov / codecov/patch

query.go#L82-L83

Added lines #L82 - L83 were not covered by tests

fields := parseStruct(v)
args := make([]any, len(columns))

for i, column := range columns {
field, ok := fields[column]
if !ok {
panic(fmt.Sprintf("queries: no field for column %q", column))
}
args[i] = field
}
if err := rows.Scan(args...); err != nil {
return zero[T](), err
}

Check warning on line 97 in query.go

View check run for this annotation

Codecov / codecov/patch

query.go#L96-L97

Added lines #L96 - L97 were not covered by tests

return t, nil
}

// TODO: add sync.Map cache.
func parseStruct(v reflect.Value) map[string]any {
fields := make(map[string]any, v.NumField())

for i := range v.NumField() {
field := v.Field(i)
if !field.CanSet() {
continue

Check warning on line 109 in query.go

View check run for this annotation

Codecov / codecov/patch

query.go#L109

Added line #L109 was not covered by tests
}

tag, ok := v.Type().Field(i).Tag.Lookup("sql")
if !ok {
continue
}
if tag == "" {
panic(fmt.Sprintf("queries: field %s has an empty `sql` tag", v.Type().Field(i).Name))
}

fields[tag] = field.Addr().Interface()
}

return fields
}
68 changes: 68 additions & 0 deletions query_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package queries

import (
"reflect"
"testing"

"go-simpler.org/queries/internal/assert"
. "go-simpler.org/queries/internal/assert/EF"
)

func Test_scan(t *testing.T) {
t.Run("non-struct T", func(t *testing.T) {
fn := func() { _, _ = scan[int](new(mockRows)) }
assert.Panics[E](t, fn, "queries: T must be a struct")
})

t.Run("empty tag", func(t *testing.T) {
type row struct {
Foo int `sql:""`
}
fn := func() { _, _ = scan[row](new(mockRows)) }
assert.Panics[E](t, fn, "queries: field Foo has an empty `sql` tag")
})

t.Run("missing field", func(t *testing.T) {
rows := mockRows{
columns: []string{"foo", "bar"},
}

type row struct {
Foo int `sql:"foo"`
Bar string
}
fn := func() { _, _ = scan[row](&rows) }
assert.Panics[E](t, fn, `queries: no field for column "bar"`)
})

t.Run("ok", func(t *testing.T) {
rows := mockRows{
columns: []string{"foo", "bar"},
values: []any{1, "A"},
}

type row struct {
Foo int `sql:"foo"`
Bar string `sql:"bar"`
}
r, err := scan[row](&rows)
assert.NoErr[F](t, err)
assert.Equal[E](t, r.Foo, 1)
assert.Equal[E](t, r.Bar, "A")
})
}

type mockRows struct {
columns []string
values []any
}

func (r *mockRows) Columns() ([]string, error) { return r.columns, nil }

func (r *mockRows) Scan(dst ...any) error {
for i := range dst {
v := reflect.ValueOf(r.values[i])
reflect.ValueOf(dst[i]).Elem().Set(v)
}
return nil
}
94 changes: 0 additions & 94 deletions scanner.go

This file was deleted.

Loading
Loading