Skip to content

Commit de67e68

Browse files
jremy42Codelax
andauthored
fix(core/reflect): handle missing values in slice with multiple elements (#3762)
Co-authored-by: Jules Casteran <[email protected]> Co-authored-by: Jules Castéran <[email protected]>
1 parent 6ef60ed commit de67e68

File tree

4 files changed

+203
-15
lines changed

4 files changed

+203
-15
lines changed

internal/core/arg_file_content.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func loadArgsFileContent(cmd *Command, cmdArgs interface{}) error {
1919
}
2020

2121
fieldName := strcase.ToPublicGoName(argSpec.Name)
22-
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
22+
fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
2323
if err != nil {
2424
continue
2525
}

internal/core/reflect.go

+18-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package core
22

33
import (
4+
"errors"
45
"fmt"
56
"reflect"
67
"sort"
@@ -34,26 +35,33 @@ func newObjectWithForcedJSONTags(t reflect.Type) interface{} {
3435
return reflect.New(reflect.StructOf(structFieldsCopy)).Interface()
3536
}
3637

37-
// getValuesForFieldByName recursively search for fields in a cmdArgs' value and returns its values if they exist.
38+
// GetValuesForFieldByName recursively search for fields in a cmdArgs' value and returns its values if they exist.
3839
// The search is based on the name of the field.
39-
func getValuesForFieldByName(value reflect.Value, parts []string) (values []reflect.Value, err error) {
40+
func GetValuesForFieldByName(value reflect.Value, parts []string) (values []reflect.Value, err error) {
4041
if len(parts) == 0 {
4142
return []reflect.Value{value}, nil
4243
}
43-
4444
switch value.Kind() {
4545
case reflect.Ptr:
46-
return getValuesForFieldByName(value.Elem(), parts)
46+
return GetValuesForFieldByName(value.Elem(), parts)
4747

4848
case reflect.Slice:
4949
values := []reflect.Value(nil)
50+
errs := []error(nil)
51+
5052
for i := 0; i < value.Len(); i++ {
51-
newValues, err := getValuesForFieldByName(value.Index(i), parts[1:])
53+
newValues, err := GetValuesForFieldByName(value.Index(i), parts[1:])
5254
if err != nil {
53-
return nil, err
55+
errs = append(errs, err)
56+
} else {
57+
values = append(values, newValues...)
5458
}
55-
values = append(values, newValues...)
5659
}
60+
61+
if len(values) == 0 && len(errs) != 0 {
62+
return nil, errors.Join(errs...)
63+
}
64+
5765
return values, nil
5866

5967
case reflect.Map:
@@ -70,7 +78,7 @@ func getValuesForFieldByName(value reflect.Value, parts []string) (values []refl
7078

7179
for _, mapKey := range mapKeys {
7280
mapValue := value.MapIndex(mapKey)
73-
newValues, err := getValuesForFieldByName(mapValue, parts[1:])
81+
newValues, err := GetValuesForFieldByName(mapValue, parts[1:])
7482
if err != nil {
7583
return nil, err
7684
}
@@ -93,19 +101,18 @@ func getValuesForFieldByName(value reflect.Value, parts []string) (values []refl
93101

94102
fieldName := strcase.ToPublicGoName(parts[0])
95103
if fieldIndex, exist := fieldIndexByName[fieldName]; exist {
96-
return getValuesForFieldByName(value.Field(fieldIndex), parts[1:])
104+
return GetValuesForFieldByName(value.Field(fieldIndex), parts[1:])
97105
}
98106

99107
// If it does not exist we try to find it in nested anonymous field
100108
for _, fieldIndex := range anonymousFieldIndexes {
101-
newValues, err := getValuesForFieldByName(value.Field(fieldIndex), parts)
109+
newValues, err := GetValuesForFieldByName(value.Field(fieldIndex), parts)
102110
if err == nil {
103111
return newValues, nil
104112
}
105113
}
106114

107115
return nil, fmt.Errorf("field %v does not exist for %v", fieldName, value.Type().Name())
108116
}
109-
110117
return nil, fmt.Errorf("case is not handled")
111118
}

internal/core/reflect_test.go

+181
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
package core_test
2+
3+
import (
4+
"net"
5+
"reflect"
6+
"strings"
7+
"testing"
8+
9+
"github.com/alecthomas/assert"
10+
"github.com/scaleway/scaleway-cli/v2/internal/core"
11+
"github.com/scaleway/scaleway-sdk-go/scw"
12+
)
13+
14+
type RequestEmbedding struct {
15+
EmbeddingField1 string
16+
EmbeddingField2 int
17+
}
18+
19+
type CreateRequest struct {
20+
*RequestEmbedding
21+
CreateField1 string
22+
CreateField2 int
23+
}
24+
25+
type ExtendedRequest struct {
26+
*CreateRequest
27+
ExtendedField1 string
28+
ExtendedField2 int
29+
}
30+
31+
type ArrowRequest struct {
32+
PrivateNetwork *PrivateNetwork
33+
}
34+
35+
type SpecialRequest struct {
36+
*RequestEmbedding
37+
TabRequest []*ArrowRequest
38+
}
39+
40+
type EndpointSpecPrivateNetwork struct {
41+
PrivateNetworkID string
42+
ServiceIP *scw.IPNet
43+
}
44+
45+
type PrivateNetwork struct {
46+
*EndpointSpecPrivateNetwork
47+
OtherValue string
48+
}
49+
50+
func Test_getValuesForFieldByName(t *testing.T) {
51+
type TestCase struct {
52+
cmdArgs interface{}
53+
fieldName string
54+
expectedError string
55+
expectedValues []reflect.Value
56+
}
57+
58+
expectedServiceIP := &scw.IPNet{
59+
IPNet: net.IPNet{
60+
IP: net.ParseIP("192.0.2.1"),
61+
Mask: net.CIDRMask(24, 32),
62+
},
63+
}
64+
65+
tests := []struct {
66+
name string
67+
testCase TestCase
68+
testFunc func(*testing.T, TestCase)
69+
}{
70+
{
71+
name: "Simple test",
72+
testCase: TestCase{
73+
cmdArgs: &ExtendedRequest{
74+
CreateRequest: &CreateRequest{
75+
RequestEmbedding: &RequestEmbedding{
76+
EmbeddingField1: "value1",
77+
EmbeddingField2: 2,
78+
},
79+
CreateField1: "value3",
80+
CreateField2: 4,
81+
},
82+
ExtendedField1: "value5",
83+
ExtendedField2: 6,
84+
},
85+
fieldName: "EmbeddingField1",
86+
expectedError: "",
87+
expectedValues: []reflect.Value{reflect.ValueOf("value1")},
88+
},
89+
testFunc: func(t *testing.T, tc TestCase) {
90+
values, err := core.GetValuesForFieldByName(reflect.ValueOf(tc.cmdArgs), strings.Split(tc.fieldName, "."))
91+
if err != nil {
92+
assert.Equal(t, tc.expectedError, err.Error())
93+
} else {
94+
if tc.expectedValues != nil && !reflect.DeepEqual(tc.expectedValues[0].Interface(), values[0].Interface()) {
95+
t.Errorf("Expected %v, got %v", tc.expectedValues[0].Interface(), values[0].Interface())
96+
}
97+
}
98+
},
99+
},
100+
{
101+
name: "Error test",
102+
testCase: TestCase{
103+
cmdArgs: &ExtendedRequest{
104+
CreateRequest: &CreateRequest{
105+
RequestEmbedding: &RequestEmbedding{
106+
EmbeddingField1: "value1",
107+
EmbeddingField2: 2,
108+
},
109+
CreateField1: "value3",
110+
CreateField2: 4,
111+
},
112+
ExtendedField1: "value5",
113+
ExtendedField2: 6,
114+
},
115+
fieldName: "NotExist",
116+
expectedError: "field NotExist does not exist for ExtendedRequest",
117+
expectedValues: []reflect.Value{reflect.ValueOf("value1")},
118+
},
119+
testFunc: func(t *testing.T, tc TestCase) {
120+
values, err := core.GetValuesForFieldByName(reflect.ValueOf(tc.cmdArgs), strings.Split(tc.fieldName, "."))
121+
if err != nil {
122+
assert.Equal(t, tc.expectedError, err.Error())
123+
} else {
124+
if tc.expectedValues != nil && !reflect.DeepEqual(tc.expectedValues[0].Interface(), values[0].Interface()) {
125+
t.Errorf("Expected %v, got %v", tc.expectedValues[0].Interface(), values[0].Interface())
126+
}
127+
}
128+
},
129+
},
130+
{
131+
132+
name: "Special test",
133+
testCase: TestCase{
134+
cmdArgs: &SpecialRequest{
135+
RequestEmbedding: &RequestEmbedding{
136+
EmbeddingField1: "value1",
137+
EmbeddingField2: 2,
138+
},
139+
TabRequest: []*ArrowRequest{
140+
{
141+
PrivateNetwork: &PrivateNetwork{
142+
EndpointSpecPrivateNetwork: &EndpointSpecPrivateNetwork{
143+
ServiceIP: &scw.IPNet{
144+
IPNet: net.IPNet{
145+
IP: net.ParseIP("192.0.2.1"),
146+
Mask: net.CIDRMask(24, 32),
147+
},
148+
},
149+
},
150+
},
151+
},
152+
{
153+
PrivateNetwork: &PrivateNetwork{
154+
OtherValue: "hello",
155+
},
156+
},
157+
},
158+
},
159+
fieldName: "tabRequest.{index}.privateNetwork.serviceIP",
160+
expectedError: "",
161+
expectedValues: []reflect.Value{reflect.ValueOf(expectedServiceIP)},
162+
},
163+
testFunc: func(t *testing.T, tc TestCase) {
164+
values, err := core.GetValuesForFieldByName(reflect.ValueOf(tc.cmdArgs), strings.Split(tc.fieldName, "."))
165+
if err != nil {
166+
assert.Equal(t, nil, err.Error())
167+
} else {
168+
if tc.expectedValues != nil && !reflect.DeepEqual(tc.expectedValues[0].Interface(), values[0].Interface()) {
169+
t.Errorf("Expected %v, got %v", tc.expectedValues[0].Interface(), values[0].Interface())
170+
}
171+
}
172+
},
173+
},
174+
}
175+
176+
for _, tt := range tests {
177+
t.Run(tt.name, func(t *testing.T) {
178+
tt.testFunc(t, tt.testCase)
179+
})
180+
}
181+
}

internal/core/validate.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func DefaultCommandValidateFunc() CommandValidateFunc {
4545
func validateArgValues(cmd *Command, cmdArgs interface{}) error {
4646
for _, argSpec := range cmd.ArgSpecs {
4747
fieldName := strcase.ToPublicGoName(argSpec.Name)
48-
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
48+
fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
4949
if err != nil {
5050
logger.Infof("could not validate arg value for '%v': invalid fieldName: %v: %v", argSpec.Name, fieldName, err.Error())
5151
continue
@@ -75,7 +75,7 @@ func validateRequiredArgs(cmd *Command, cmdArgs interface{}, rawArgs args.RawArg
7575
}
7676

7777
fieldName := strcase.ToPublicGoName(arg.Name)
78-
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
78+
fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
7979
if err != nil {
8080
validationErr := fmt.Errorf("could not validate arg value for '%v': invalid field name '%v': %v", arg.Name, fieldName, err.Error())
8181
if !arg.Required {
@@ -117,7 +117,7 @@ func validateDeprecated(ctx context.Context, cmd *Command, cmdArgs interface{},
117117
deprecatedArgs := cmd.ArgSpecs.GetDeprecated(true)
118118
for _, arg := range deprecatedArgs {
119119
fieldName := strcase.ToPublicGoName(arg.Name)
120-
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
120+
fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
121121
if err != nil {
122122
validationErr := fmt.Errorf("could not validate arg value for '%v': invalid field name '%v': %v", arg.Name, fieldName, err.Error())
123123
if !arg.Required {

0 commit comments

Comments
 (0)