Skip to content

Commit 8ff85cb

Browse files
ir: Fix nil pointer deref in Unmarshal() when handling IsSetStmt (#7430)
Fixes: #7415 Signed-off-by: Kris Kennaway <[email protected]>
1 parent f6c20b0 commit 8ff85cb

File tree

2 files changed

+81
-48
lines changed

2 files changed

+81
-48
lines changed

v1/ir/encoding/encoding_test.go

+69-46
Original file line numberDiff line numberDiff line change
@@ -11,62 +11,85 @@ import (
1111
)
1212

1313
func TestRoundTrip(t *testing.T) {
14+
tests := []struct {
15+
note string
16+
modules map[string]string
17+
}{
18+
{
19+
note: "simple",
20+
modules: map[string]string{
21+
"test.rego": `
22+
package test
23+
p if {
24+
input.foo == 7
25+
}
26+
`,
27+
},
28+
},
29+
{
30+
note: "every",
31+
modules: map[string]string{
32+
"test.rego": `
33+
package test
34+
p if {
35+
every i in input.foo { i > 0 }
36+
}
37+
`,
38+
},
39+
},
40+
}
1441

15-
// Note: v1 module
16-
c, err := ast.CompileModules(map[string]string{
17-
"test.rego": `
18-
package test
42+
for _, tc := range tests {
43+
t.Run(tc.note, func(t *testing.T) {
44+
// Note: v1 module
45+
c, err := ast.CompileModules(tc.modules)
1946

20-
p if {
21-
input.foo == 7
47+
if err != nil {
48+
t.Fatal(err)
2249
}
23-
`,
24-
})
25-
26-
if err != nil {
27-
t.Fatal(err)
28-
}
2950

30-
modules := []*ast.Module{}
51+
modules := []*ast.Module{}
3152

32-
for _, m := range c.Modules {
33-
modules = append(modules, m)
34-
}
53+
for _, m := range c.Modules {
54+
modules = append(modules, m)
55+
}
3556

36-
planner := planner.New().
37-
WithQueries([]planner.QuerySet{
38-
{
39-
Name: "main",
40-
Queries: []ast.Body{
41-
ast.MustParseBody("data.test.p = true"),
42-
},
43-
},
44-
}).
45-
WithModules(modules).
46-
WithBuiltinDecls(ast.BuiltinMap)
57+
planner := planner.New().
58+
WithQueries([]planner.QuerySet{
59+
{
60+
Name: "main",
61+
Queries: []ast.Body{
62+
ast.MustParseBody("data.test.p = true"),
63+
},
64+
},
65+
}).
66+
WithModules(modules).
67+
WithBuiltinDecls(ast.BuiltinMap)
4768

48-
plan, err := planner.Plan()
49-
if err != nil {
50-
t.Fatal(err)
51-
}
69+
plan, err := planner.Plan()
70+
if err != nil {
71+
t.Fatal(err)
72+
}
5273

53-
bs, err := json.MarshalIndent(plan, "", " ")
54-
if err != nil {
55-
t.Fatal(err)
56-
}
74+
bs, err := json.MarshalIndent(plan, "", " ")
75+
if err != nil {
76+
t.Fatal(err)
77+
}
5778

58-
var cpy ir.Policy
59-
err = json.Unmarshal(bs, &cpy)
60-
if err != nil {
61-
t.Fatal(err)
62-
}
79+
var cpy ir.Policy
80+
err = json.Unmarshal(bs, &cpy)
81+
if err != nil {
82+
t.Fatal(err)
83+
}
6384

64-
bs2, err := json.MarshalIndent(plan, "", " ")
65-
if err != nil {
66-
t.Fatal(err)
67-
}
85+
bs2, err := json.MarshalIndent(plan, "", " ")
86+
if err != nil {
87+
t.Fatal(err)
88+
}
6889

69-
if !bytes.Equal(bs, bs2) {
70-
t.Fatal("expected bytes to be equal")
90+
if !bytes.Equal(bs, bs2) {
91+
t.Fatal("expected bytes to be equal")
92+
}
93+
})
7194
}
7295
}

v1/ir/marshal.go

+12-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package ir
66

77
import (
88
"encoding/json"
9+
"fmt"
910
"reflect"
1011
)
1112

@@ -50,7 +51,11 @@ func (a *Operand) UnmarshalJSON(bs []byte) error {
5051
if err := json.Unmarshal(bs, &typed); err != nil {
5152
return err
5253
}
53-
x := valFactories[typed.Type]()
54+
f, ok := valFactories[typed.Type]
55+
if !ok {
56+
return fmt.Errorf("unrecognized value type %q", typed.Type)
57+
}
58+
x := f()
5459
if err := json.Unmarshal(typed.Value, &x); err != nil {
5560
return err
5661
}
@@ -77,7 +82,11 @@ type rawTypedStmt struct {
7782
}
7883

7984
func (raw rawTypedStmt) Unmarshal() (Stmt, error) {
80-
x := stmtFactories[raw.Type]()
85+
f, ok := stmtFactories[raw.Type]
86+
if !ok {
87+
return nil, fmt.Errorf("unrecognized statement type %q", raw.Type)
88+
}
89+
x := f()
8190
if err := json.Unmarshal(raw.Stmt, &x); err != nil {
8291
return nil, err
8392
}
@@ -119,6 +128,7 @@ var stmtFactories = map[string]func() Stmt{
119128
"IsArrayStmt": func() Stmt { return &IsArrayStmt{} },
120129
"IsObjectStmt": func() Stmt { return &IsObjectStmt{} },
121130
"IsDefinedStmt": func() Stmt { return &IsDefinedStmt{} },
131+
"IsSetStmt": func() Stmt { return &IsSetStmt{} },
122132
"IsUndefinedStmt": func() Stmt { return &IsUndefinedStmt{} },
123133
"ArrayAppendStmt": func() Stmt { return &ArrayAppendStmt{} },
124134
"ObjectInsertStmt": func() Stmt { return &ObjectInsertStmt{} },

0 commit comments

Comments
 (0)