Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add flags support #50

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 65 additions & 14 deletions env.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,35 @@
package env

import (
"flag"
"fmt"
"os"
"reflect"
"strconv"
"strings"
)

// Options are the options for the [Load] function.
type Options struct {
Source Source // The source of environment variables. The default is [OS].
SliceSep string // The separator used to parse slice values. The default is space.
Source Source // The source of environment variables. The default is [OS].
SliceSep string // The separator used to parse slice values. The default is space.
FlagSet *flag.FlagSet // ...
FlagArgs []string // ...
}

func (o *Options) init() {
if o.Source == nil {
o.Source = OS
}
if o.SliceSep == "" {
o.SliceSep = " "
}
if o.FlagSet == nil {
o.FlagSet = flag.CommandLine
}
if o.FlagArgs == nil {
o.FlagArgs = os.Args[1:]
}
}

// NotSetError is returned when environment variables are marked as required but not set.
Expand Down Expand Up @@ -57,12 +76,7 @@ func Load(cfg any, opts *Options) error {
if opts == nil {
opts = new(Options)
}
if opts.Source == nil {
opts.Source = OS
}
if opts.SliceSep == "" {
opts.SliceSep = " "
}
opts.init()

pv := reflect.ValueOf(cfg)
if !structPtr(pv) {
Expand All @@ -73,10 +87,34 @@ func Load(cfg any, opts *Options) error {
vars := parseVars(v)
cache[v.Type()] = vars

for _, v := range vars {
if v.Flag == "" {
continue
}
if v.Type.Kind() != reflect.Bool {
opts.FlagSet.String(v.Flag, v.Default, v.Usage)
continue
}
// handle flags without a value, e.g. -help.
value, err := strconv.ParseBool(v.Default) // TODO: default may be empty.
if err != nil {
return fmt.Errorf("parsing bool: %w", err)
}
opts.FlagSet.Bool(v.Flag, value, v.Usage)
}

if err := opts.FlagSet.Parse(opts.FlagArgs); err != nil {
return fmt.Errorf("parsing flags: %w", err)
}

var notset []string
for _, v := range vars {
value, ok := lookupEnv(opts.Source, v.Name, v.Expand)
if !ok {
value, envSet := lookupEnv(opts.Source, v.Name, v.Expand)
flagValue, flagSet := lookupFlag(opts.FlagSet, v.Flag)
if flagSet {
value = flagValue // flags have higher priority.
}
if !envSet && !flagSet {
if v.Required {
notset = append(notset, v.Name)
continue
Expand Down Expand Up @@ -120,8 +158,8 @@ func parseVars(v reflect.Value) []Var {
continue
}

sf := v.Type().Field(i)
value, ok := sf.Tag.Lookup("env")
tags := v.Type().Field(i).Tag
value, ok := tags.Lookup("env")
if !ok {
continue
}
Expand All @@ -144,7 +182,7 @@ func parseVars(v reflect.Value) []Var {
}
}

defValue, defSet := sf.Tag.Lookup("default")
defValue, defSet := tags.Lookup("default")
switch {
case defSet && required:
panic("env: `required` and `default` can't be used simultaneously")
Expand All @@ -155,7 +193,8 @@ func parseVars(v reflect.Value) []Var {
vars = append(vars, Var{
Name: name,
Type: field.Type(),
Usage: sf.Tag.Get("usage"),
Flag: tags.Get("flag"),
Usage: tags.Get("usage"),
Default: defValue,
Required: required,
Expand: expand,
Expand All @@ -181,3 +220,15 @@ func lookupEnv(src Source, key string, expand bool) (string, bool) {
}
return os.Expand(value, mapping), true
}

func lookupFlag(fs *flag.FlagSet, name string) (string, bool) {
var value string
var set bool
fs.Visit(func(fl *flag.Flag) {
if fl.Name == name {
value = fl.Value.String()
set = true
}
})
return value, set
}
1 change: 1 addition & 0 deletions usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var cache = make(map[reflect.Type][]Var)
type Var struct {
Name string // The name of the variable.
Type reflect.Type // The type of the variable.
Flag string // The flag to overwrite the variable.
Usage string // The usage string parsed from the `usage` tag (if exists).
Default string // The default value of the variable. Empty, if the variable is required.
Required bool // True, if the variable is marked as required.
Expand Down