diff --git a/example_test.go b/example_test.go index ad7a162..0ad3a1c 100644 --- a/example_test.go +++ b/example_test.go @@ -35,28 +35,25 @@ func run(ctx context.Context) error { columns = append(columns, "created_at") } + // select first_name, last_name, created_at from users where created_at >= $1 var qb queries.Builder qb.Appendf("select %s from users", strings.Join(columns, ", ")) if true { qb.Appendf(" where created_at >= %$", time.Date(2024, time.January, 1, 0, 0, 0, 0, time.Local)) } - // select first_name, last_name, created_at from users where created_at >= $1 - rows, err := db.QueryContext(ctx, qb.Query(), qb.Args()...) - if err != nil { - return err - } - defer rows.Close() - - var users []struct { + type user struct { FirstName string `sql:"first_name"` LastName string `sql:"last_name"` CreatedAt time.Time `sql:"created_at"` } - if err := queries.Scan(&users, rows); err != nil { - return err + + for user, err := range queries.Query[user](ctx, db, qb.Query(), qb.Args()...) { + if err != nil { + return err + } + fmt.Println(user) } - fmt.Println(users) return nil } diff --git a/go.mod b/go.mod index 722e4f3..7470367 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module go-simpler.org/queries -go 1.22 +go 1.23 diff --git a/query.go b/query.go new file mode 100644 index 0000000..4f6e673 --- /dev/null +++ b/query.go @@ -0,0 +1,124 @@ +package queries + +import ( + "context" + "database/sql" + "fmt" + "iter" + "reflect" +) + +type queryer interface { + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) +} + +func Query[T any](ctx context.Context, q queryer, query string, args ...any) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { + rows, err := q.QueryContext(ctx, query, args...) + if err != nil { + yield(zero[T](), err) + return + } + defer rows.Close() + + for rows.Next() { + t, err := scan[T](rows) + if err != nil { + yield(zero[T](), err) + return + } + if !yield(t, nil) { + return + } + } + if err := rows.Err(); err != nil { + yield(zero[T](), err) + return + } + } +} + +func QueryRow[T any](ctx context.Context, q queryer, query string, args ...any) (T, error) { + rows, err := q.QueryContext(ctx, query, args...) + if err != nil { + return zero[T](), err + } + defer rows.Close() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return zero[T](), err + } + return zero[T](), sql.ErrNoRows + } + + t, err := scan[T](rows) + if err != nil { + return zero[T](), err + } + if err := rows.Err(); err != nil { + return zero[T](), err + } + + return t, nil +} + +func zero[T any]() (t T) { return t } + +type rows interface { + Columns() ([]string, error) + Scan(...any) error +} + +func scan[T any](rows rows) (T, error) { + var t T + v := reflect.ValueOf(&t).Elem() + if v.Kind() != reflect.Struct { + panic("queries: T must be a struct") + } + + columns, err := rows.Columns() + if err != nil { + return zero[T](), fmt.Errorf("getting column names: %w", err) + } + + fields := parseStruct(v) + args := make([]any, len(columns)) + + for i, column := range columns { + field, ok := fields[column] + if !ok { + panic(fmt.Sprintf("queries: no field for column %q", column)) + } + args[i] = field + } + if err := rows.Scan(args...); err != nil { + return zero[T](), err + } + + return t, nil +} + +// TODO: add sync.Map cache. +func parseStruct(v reflect.Value) map[string]any { + fields := make(map[string]any, v.NumField()) + + for i := range v.NumField() { + field := v.Field(i) + if !field.CanSet() { + continue + } + + tag, ok := v.Type().Field(i).Tag.Lookup("sql") + if !ok { + continue + } + if tag == "" { + panic(fmt.Sprintf("queries: field %s has an empty `sql` tag", v.Type().Field(i).Name)) + } + + fields[tag] = field.Addr().Interface() + } + + return fields +} diff --git a/query_test.go b/query_test.go new file mode 100644 index 0000000..aaf4772 --- /dev/null +++ b/query_test.go @@ -0,0 +1,68 @@ +package queries + +import ( + "reflect" + "testing" + + "go-simpler.org/queries/internal/assert" + . "go-simpler.org/queries/internal/assert/EF" +) + +func Test_scan(t *testing.T) { + t.Run("non-struct T", func(t *testing.T) { + fn := func() { _, _ = scan[int](new(mockRows)) } + assert.Panics[E](t, fn, "queries: T must be a struct") + }) + + t.Run("empty tag", func(t *testing.T) { + type row struct { + Foo int `sql:""` + } + fn := func() { _, _ = scan[row](new(mockRows)) } + assert.Panics[E](t, fn, "queries: field Foo has an empty `sql` tag") + }) + + t.Run("missing field", func(t *testing.T) { + rows := mockRows{ + columns: []string{"foo", "bar"}, + } + + type row struct { + Foo int `sql:"foo"` + Bar string + } + fn := func() { _, _ = scan[row](&rows) } + assert.Panics[E](t, fn, `queries: no field for column "bar"`) + }) + + t.Run("ok", func(t *testing.T) { + rows := mockRows{ + columns: []string{"foo", "bar"}, + values: []any{1, "A"}, + } + + type row struct { + Foo int `sql:"foo"` + Bar string `sql:"bar"` + } + r, err := scan[row](&rows) + assert.NoErr[F](t, err) + assert.Equal[E](t, r.Foo, 1) + assert.Equal[E](t, r.Bar, "A") + }) +} + +type mockRows struct { + columns []string + values []any +} + +func (r *mockRows) Columns() ([]string, error) { return r.columns, nil } + +func (r *mockRows) Scan(dst ...any) error { + for i := range dst { + v := reflect.ValueOf(r.values[i]) + reflect.ValueOf(dst[i]).Elem().Set(v) + } + return nil +} diff --git a/scanner.go b/scanner.go deleted file mode 100644 index 27fb3fc..0000000 --- a/scanner.go +++ /dev/null @@ -1,94 +0,0 @@ -package queries - -import ( - "database/sql" - "fmt" - "reflect" -) - -type Rows interface { - Columns() ([]string, error) - Next() bool - Scan(...any) error - Err() error -} - -func Scan[T any](dst *[]T, rows Rows) error { - return scan[T](reflect.ValueOf(dst).Elem(), rows) -} - -func ScanRow[T any](dst *T, rows Rows) error { - return scan[T](reflect.ValueOf(dst).Elem(), rows) -} - -func scan[T any](v reflect.Value, rows Rows) error { - typ := reflect.TypeFor[T]() - if typ.Kind() != reflect.Struct { - panic("queries: T must be a struct") - } - - strct := reflect.New(typ).Elem() - fields := parseStruct(strct) - - columns, err := rows.Columns() - if err != nil { - return fmt.Errorf("getting column names: %w", err) - } - - into := make([]any, len(columns)) - for i, column := range columns { - field, ok := fields[column] - if !ok { - panic(fmt.Sprintf("queries: no field for column %q", column)) - } - into[i] = field - } - - slice := reflect.New(reflect.SliceOf(typ)).Elem() - for rows.Next() { - if err := rows.Scan(into...); err != nil { - return fmt.Errorf("scanning row: %w", err) - } - slice = reflect.Append(slice, strct) - } - if err := rows.Err(); err != nil { - return err - } - - switch v.Kind() { - case reflect.Slice: - v.Set(slice) - case reflect.Struct: - if slice.Len() == 0 { - return sql.ErrNoRows - } - v.Set(slice.Index(0)) - default: - panic("unreachable") - } - - return nil -} - -func parseStruct(v reflect.Value) map[string]any { - fields := make(map[string]any, v.NumField()) - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - if !field.CanSet() { - continue - } - - tag, ok := v.Type().Field(i).Tag.Lookup("sql") - if !ok { - continue - } - if tag == "" { - panic(fmt.Sprintf("queries: field %s has an empty `sql` tag", v.Type().Field(i).Name)) - } - - fields[tag] = field.Addr().Interface() - } - - return fields -} diff --git a/scanner_test.go b/scanner_test.go deleted file mode 100644 index 2603127..0000000 --- a/scanner_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package queries_test - -import ( - "database/sql" - "reflect" - "testing" - - "go-simpler.org/queries" - "go-simpler.org/queries/internal/assert" - . "go-simpler.org/queries/internal/assert/EF" -) - -func Test_misuse(t *testing.T) { - t.Run("non-struct T", func(t *testing.T) { - const panicMsg = "queries: T must be a struct" - - assert.Panics[E](t, func() { _ = queries.Scan(new([]int), nil) }, panicMsg) - assert.Panics[E](t, func() { _ = queries.ScanRow(new(int), nil) }, panicMsg) - }) - - t.Run("empty tag", func(t *testing.T) { - const panicMsg = "queries: field Foo has an empty `sql` tag" - - type dst struct { - Foo int `sql:""` - } - assert.Panics[E](t, func() { _ = queries.Scan(new([]dst), nil) }, panicMsg) - assert.Panics[E](t, func() { _ = queries.ScanRow(new(dst), nil) }, panicMsg) - }) - - t.Run("missing field", func(t *testing.T) { - const panicMsg = `queries: no field for column "foo"` - - rows := mockRows{columns: []string{"foo"}} - - type dst struct { - Foo int - } - assert.Panics[E](t, func() { _ = queries.Scan(new([]dst), &rows) }, panicMsg) - assert.Panics[E](t, func() { _ = queries.ScanRow(new(dst), &rows) }, panicMsg) - }) -} - -func TestScan(t *testing.T) { - rows := mockRows{ - columns: []string{"foo", "bar"}, - values: [][]any{{1, "A"}, {2, "B"}}, - } - - var dst []struct { - Foo int `sql:"foo"` - Bar string `sql:"bar"` - } - err := queries.Scan(&dst, &rows) - assert.NoErr[F](t, err) - assert.Equal[E](t, len(dst), 2) - assert.Equal[E](t, dst[0].Foo, 1) - assert.Equal[E](t, dst[1].Foo, 2) - assert.Equal[E](t, dst[0].Bar, "A") - assert.Equal[E](t, dst[1].Bar, "B") -} - -func TestScanRow(t *testing.T) { - rows := mockRows{ - columns: []string{"foo", "bar"}, - values: [][]any{{1, "A"}}, - } - - var dst struct { - Foo int `sql:"foo"` - Bar string `sql:"bar"` - } - err := queries.ScanRow(&dst, &rows) - assert.NoErr[F](t, err) - assert.Equal[E](t, dst.Foo, 1) - assert.Equal[E](t, dst.Bar, "A") - - t.Run("no rows", func(t *testing.T) { - rows := mockRows{columns: []string{"foo"}} - - var dst struct { - Foo int `sql:"foo"` - } - err := queries.ScanRow(&dst, &rows) - assert.IsErr[E](t, err, sql.ErrNoRows) - }) -} - -type mockRows struct { - columns []string - values [][]any - idx int -} - -func (r *mockRows) Columns() ([]string, error) { return r.columns, nil } -func (r *mockRows) Next() bool { return r.idx < len(r.values) } -func (r *mockRows) Err() error { return nil } - -func (r *mockRows) Scan(dst ...any) error { - for i := 0; i < len(dst); i++ { - v := reflect.ValueOf(r.values[r.idx][i]) - reflect.ValueOf(dst[i]).Elem().Set(v) - } - r.idx++ - return nil -}