Skip to content

Commit

Permalink
User-friendly error for invalid envs/flags (#3414)
Browse files Browse the repository at this point in the history
* fix: user-friendly error for invalid envs

* chore: code improvements

* chore: code comment
  • Loading branch information
levkohimins authored Sep 17, 2024
1 parent db3e671 commit 814fb36
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 18 deletions.
16 changes: 14 additions & 2 deletions pkg/cli/bool_flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cli
import (
libflag "flag"

"github.com/gruntwork-io/go-commons/errors"
"github.com/urfave/cli/v2"
)

Expand Down Expand Up @@ -42,11 +43,22 @@ func (flag *BoolFlag) Apply(set *libflag.FlagSet) error {
flag.Destination = new(bool)
}

var err error
var (
err error
envValue *string
)

valType := FlagType[bool](&boolFlagType{negative: flag.Negative})

if flag.FlagValue, err = newGenericValue(valType, flag.LookupEnv(flag.EnvVar), flag.Destination); err != nil {
if val := flag.LookupEnv(flag.EnvVar); val != nil && *val != "" {
envValue = val
}

if flag.FlagValue, err = newGenericValue(valType, envValue, flag.Destination); err != nil {
if envValue != nil {
return errors.Errorf("invalid boolean value %q for %s: %w", *envValue, flag.EnvVar, err)
}

return err
}

Expand Down
24 changes: 21 additions & 3 deletions pkg/cli/bool_flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/gruntwork-io/terragrunt/pkg/cli"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
)

func TestBoolFlagApply(t *testing.T) {
Expand Down Expand Up @@ -86,6 +87,20 @@ func TestBoolFlagApply(t *testing.T) {
false,
errors.New(`invalid boolean flag foo: setting the flag multiple times`),
},
{
cli.BoolFlag{Name: "foo", EnvVar: "FOO"},
nil,
map[string]string{"FOO": ""},
false,
nil,
},
{
cli.BoolFlag{Name: "foo", EnvVar: "FOO"},
nil,
map[string]string{"FOO": "monkey"},
false,
errors.New(`invalid boolean value "monkey" for FOO: must be one of: "0", "1", "f", "t", "false", "true"`),
},
}

for i, testCase := range testCases {
Expand Down Expand Up @@ -130,11 +145,12 @@ func testBoolFlagApply(t *testing.T, flag *cli.BoolFlag, args []string, envs map
flagSet.SetOutput(io.Discard)

err := flag.Apply(flagSet)
require.NoError(t, err)
if err == nil {
err = flagSet.Parse(args)
}

err = flagSet.Parse(args)
if expectedErr != nil {
require.Equal(t, expectedErr, err)
require.ErrorContains(t, expectedErr, err.Error())
return
}
require.NoError(t, err)
Expand All @@ -148,6 +164,8 @@ func testBoolFlagApply(t *testing.T, flag *cli.BoolFlag, args []string, envs map
assert.Equal(t, strconv.FormatBool(expectedValue), flag.GetValue(), "GetValue()")
}

maps.DeleteFunc(envs, func(k, v string) bool { return v == "" })

assert.Equal(t, len(args) > 0 || len(envs) > 0, flag.Value().IsSet(), "IsSet()")
assert.Equal(t, expectedDefaultValue, flag.Value().GetDefaultText(), "GetDefaultText()")

Expand Down
14 changes: 14 additions & 0 deletions pkg/cli/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,17 @@ func handleExitCoder(err error, osExiter func(code int)) error {

return err
}

// InvalidValueError is used to wrap errors from `strconv` to make the error message more user friendly.
type InvalidValueError struct {
underlyingError error
msg string
}

func (err InvalidValueError) Error() string {
return err.msg
}

func (err InvalidValueError) Unwrap() error {
return err.underlyingError
}
23 changes: 17 additions & 6 deletions pkg/cli/generic_flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,22 @@ func (flag *GenericFlag[T]) Apply(set *libflag.FlagSet) error {
flag.Destination = new(T)
}

var err error
var (
err error
envValue *string
)

valType := FlagType[T](new(genericType[T]))

if flag.FlagValue, err = newGenericValue(valType, flag.LookupEnv(flag.EnvVar), flag.Destination); err != nil {
if val := flag.LookupEnv(flag.EnvVar); val != nil {
envValue = val
}

if flag.FlagValue, err = newGenericValue(valType, envValue, flag.Destination); err != nil {
if envValue != nil {
return errors.Errorf("invalid value %q for %s: %w", *envValue, flag.EnvVar, err)
}

return err
}

Expand Down Expand Up @@ -199,31 +210,31 @@ func (val *genericType[T]) Set(str string) error {
case *bool:
v, err := strconv.ParseBool(str)
if err != nil {
return errors.Errorf("error parse: %w", err)
return errors.WithStackTrace(InvalidValueError{underlyingError: err, msg: `must be one of: "0", "1", "f", "t", "false", "true"`})
}

*dest = v

case *int:
v, err := strconv.ParseInt(str, 0, strconv.IntSize)
if err != nil {
return errors.Errorf("error parse: %w", err)
return errors.WithStackTrace(InvalidValueError{underlyingError: err, msg: "must be 32-bit integer"})
}

*dest = int(v)

case *uint:
v, err := strconv.ParseUint(str, 10, 64)
if err != nil {
return errors.Errorf("error parse: %w", err)
return errors.WithStackTrace(InvalidValueError{underlyingError: err, msg: "must be 32-bit unsigned integer"})
}

*dest = uint(v)

case *int64:
v, err := strconv.ParseInt(str, 0, 64)
if err != nil {
return errors.Errorf("error parse: %w", err)
return errors.WithStackTrace(InvalidValueError{underlyingError: err, msg: "must be 64-bit integer"})
}

*dest = v
Expand Down
21 changes: 18 additions & 3 deletions pkg/cli/generic_flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ func TestGenericFlagIntApply(t *testing.T) {
20,
nil,
},
{
cli.GenericFlag[int]{Name: "foo", EnvVar: "FOO"},
[]string{},
map[string]string{"FOO": "monkey"},
0,
errors.New(`invalid value "monkey" for FOO: must be 32-bit integer`),
},
{
cli.GenericFlag[int]{Name: "foo", Destination: mockDestValue(55)},
nil,
Expand Down Expand Up @@ -145,6 +152,13 @@ func TestGenericFlagInt64Apply(t *testing.T) {
20,
nil,
},
{
cli.GenericFlag[int64]{Name: "foo", EnvVar: "FOO"},
[]string{},
map[string]string{"FOO": "monkey"},
0,
errors.New(`invalid value "monkey" for FOO: must be 64-bit integer`),
},
{
cli.GenericFlag[int64]{Name: "foo", Destination: mockDestValue(int64(55))},
nil,
Expand Down Expand Up @@ -196,11 +210,12 @@ func testGenericFlagApply[T cli.GenericType](t *testing.T, flag *cli.GenericFlag
flagSet.SetOutput(io.Discard)

err := flag.Apply(flagSet)
require.NoError(t, err)
if err == nil {
err = flagSet.Parse(args)
}

err = flagSet.Parse(args)
if expectedErr != nil {
require.Equal(t, expectedErr, err)
require.ErrorContains(t, expectedErr, err.Error())
return
}
require.NoError(t, err)
Expand Down
15 changes: 13 additions & 2 deletions pkg/cli/map_flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,23 @@ func (flag *MapFlag[K, V]) Apply(set *libflag.FlagSet) error {
flag.KeyValSep = MapFlagKeyValSep
}

var err error
var (
err error
envValue *string
)

keyType := FlagType[K](new(genericType[K]))
valType := FlagType[V](new(genericType[V]))

if flag.FlagValue, err = newMapValue(keyType, valType, flag.LookupEnv(flag.EnvVar), flag.EnvVarSep, flag.KeyValSep, flag.Splitter, flag.Destination); err != nil {
if val := flag.LookupEnv(flag.EnvVar); val != nil {
envValue = val
}

if flag.FlagValue, err = newMapValue(keyType, valType, envValue, flag.EnvVarSep, flag.KeyValSep, flag.Splitter, flag.Destination); err != nil {
if envValue != nil {
return errors.Errorf("invalid value %q for %s: %w", *envValue, flag.EnvVar, err)
}

return err
}

Expand Down
16 changes: 14 additions & 2 deletions pkg/cli/slice_flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
libflag "flag"
"strings"

"github.com/gruntwork-io/go-commons/errors"
"github.com/urfave/cli/v2"
)

Expand Down Expand Up @@ -62,11 +63,22 @@ func (flag *SliceFlag[T]) Apply(set *libflag.FlagSet) error {
flag.EnvVarSep = SliceFlagEnvVarSep
}

var err error
var (
err error
envValue *string
)

valType := FlagType[T](new(genericType[T]))

if flag.FlagValue, err = newSliceValue(valType, flag.LookupEnv(flag.EnvVar), flag.EnvVarSep, flag.Splitter, flag.Destination); err != nil {
if val := flag.LookupEnv(flag.EnvVar); val != nil {
envValue = val
}

if flag.FlagValue, err = newSliceValue(valType, envValue, flag.EnvVarSep, flag.Splitter, flag.Destination); err != nil {
if envValue != nil {
return errors.Errorf("invalid value %q for %s: %w", *envValue, flag.EnvVar, err)
}

return err
}

Expand Down

0 comments on commit 814fb36

Please sign in to comment.