Skip to content

Commit ec47987

Browse files
committed
Fix walk to support AST struct member walk
1 parent ff6ce86 commit ec47987

File tree

4 files changed

+119
-40
lines changed

4 files changed

+119
-40
lines changed

pkg/util/set/dedup.go

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package set
2+
3+
import (
4+
"sort"
5+
)
6+
7+
func SortDeDup(l []string) []string {
8+
n := len(l)
9+
if n <= 1 {
10+
return l
11+
}
12+
sort.Strings(l)
13+
14+
j := 1
15+
for i := 1; i < n; i++ {
16+
if l[i] != l[i-1] {
17+
l[j] = l[i]
18+
j++
19+
}
20+
}
21+
22+
return l[0:j]
23+
}

pkg/util/set/dedup_test.go

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package set
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
)
7+
8+
func TestSortDeDup(t *testing.T) {
9+
{
10+
l := []string{"c", "a", "b", "c", "e", "d"}
11+
expected := []string{"a", "b", "c", "d", "e"}
12+
sl := SortDeDup(l)
13+
if fmt.Sprint(sl) != fmt.Sprint(expected) {
14+
t.Errorf("%v should be equal %v", sl, expected)
15+
}
16+
}
17+
{
18+
l := []string{"mj.mc_trscode", "mj.rowid", "mj.tans_amt", "mj.trans_bran_code", "mj.trans_date", "mj.trans_flag", "trans_date", "mj.trans_flag", "trans_date"}
19+
expected := []string{"mj.mc_trscode", "mj.rowid", "mj.tans_amt", "mj.trans_bran_code", "mj.trans_date", "mj.trans_flag", "trans_date"}
20+
sl := SortDeDup(l)
21+
if fmt.Sprint(sl) != fmt.Sprint(expected) {
22+
t.Errorf("%v should be equal %v", sl, expected)
23+
}
24+
}
25+
{
26+
l := []string{"c"}
27+
expected := []string{"c"}
28+
sl := SortDeDup(l)
29+
if fmt.Sprint(sl) != fmt.Sprint(expected) {
30+
t.Errorf("%v should be equal %v", sl, expected)
31+
}
32+
}
33+
{
34+
l := make([]string, 0)
35+
expected := []string{}
36+
sl := SortDeDup(l)
37+
if fmt.Sprint(sl) != fmt.Sprint(expected) {
38+
t.Errorf("%v should be equal %v", sl, expected)
39+
}
40+
}
41+
{
42+
var l []string
43+
expected := []string(nil)
44+
sl := SortDeDup(l)
45+
if fmt.Sprint(sl) != fmt.Sprint(expected) {
46+
t.Errorf("%v should be equal %v", sl, expected)
47+
}
48+
}
49+
}

pkg/walk/walker.go

+46-20
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,32 @@ package walk
33
import (
44
"fmt"
55
"log"
6-
"sort"
76
"strings"
87

98
"github.com/auxten/postgresql-parser/pkg/sql/parser"
109
"github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
10+
"github.com/auxten/postgresql-parser/pkg/util/set"
1111
)
1212

1313
type AstWalker struct {
14-
unknownNodes []interface{}
14+
UnknownNodes []interface{}
1515
Fn func(ctx interface{}, node interface{}) (stop bool)
1616
}
1717
type ReferredCols map[string]int
1818

1919
func (rc ReferredCols) ToList() []string {
2020
cols := make([]string, len(rc))
2121
i := 0
22-
for k, _ := range rc {
22+
for k := range rc {
2323
cols[i] = k
2424
i++
2525
}
26-
sort.Strings(cols)
27-
return cols
26+
return set.SortDeDup(cols)
2827
}
2928

3029
func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err error) {
3130

32-
w.unknownNodes = make([]interface{}, 0)
31+
w.UnknownNodes = make([]interface{}, 0)
3332
asts := make([]tree.NodeFormatter, len(stmts))
3433
for si, stmt := range stmts {
3534
asts[si] = stmt.AST
@@ -69,8 +68,8 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
6968
walk(node.Left, node.Right)
7069
case *tree.CaseExpr:
7170
walk(node.Expr, node.Else)
72-
for _, w := range node.Whens {
73-
walk(w.Cond, w.Val)
71+
for _, when := range node.Whens {
72+
walk(when.Cond, when.Val)
7473
}
7574
case *tree.RangeCond:
7675
walk(node.Left, node.From, node.To)
@@ -98,6 +97,11 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
9897
walk(expr)
9998
}
10099
case *tree.FamilyTableDef:
100+
case *tree.From:
101+
walk(node.AsOf)
102+
for _, table := range node.Tables {
103+
walk(table)
104+
}
101105
case *tree.FuncExpr:
102106
if node.WindowDef != nil {
103107
walk(node.WindowDef)
@@ -111,6 +115,12 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
111115
case *tree.NumVal:
112116
case *tree.OnJoinCond:
113117
walk(node.Expr)
118+
case *tree.Order:
119+
walk(node.Expr, node.Table)
120+
case tree.OrderBy:
121+
for _, order := range node {
122+
walk(order)
123+
}
114124
case *tree.OrExpr:
115125
walk(node.Left, node.Right)
116126
case *tree.ParenExpr:
@@ -126,16 +136,12 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
126136
walk(node.With)
127137
}
128138
if node.OrderBy != nil {
129-
for _, order := range node.OrderBy {
130-
walk(order)
131-
}
139+
walk(node.OrderBy)
132140
}
133141
if node.Limit != nil {
134142
walk(node.Limit)
135143
}
136144
walk(node.Select)
137-
case *tree.Order:
138-
walk(node.Expr, node.Table)
139145
case *tree.Limit:
140146
walk(node.Count)
141147
case *tree.SelectClause:
@@ -156,10 +162,7 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
156162
walk(group)
157163
}
158164
}
159-
walk(node.From.AsOf)
160-
for _, table := range node.From.Tables {
161-
walk(table)
162-
}
165+
walk(&node.From)
163166
case tree.SelectExpr:
164167
walk(node.Expr)
165168
case tree.SelectExprs:
@@ -173,6 +176,10 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
173176
case *tree.StrVal:
174177
case *tree.Subquery:
175178
walk(node.Select)
179+
case tree.TableExprs:
180+
for _, expr := range node {
181+
walk(expr)
182+
}
176183
case *tree.TableName, tree.TableName:
177184
case *tree.Tuple:
178185
for _, expr := range node.Exprs {
@@ -214,8 +221,8 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
214221
walk(expr)
215222
}
216223
default:
217-
if w.unknownNodes != nil {
218-
w.unknownNodes = append(w.unknownNodes, node)
224+
if w.UnknownNodes != nil {
225+
w.UnknownNodes = append(w.UnknownNodes, node)
219226
}
220227
}
221228
}
@@ -270,8 +277,27 @@ func ColNamesInSelect(sql string) (referredCols ReferredCols, err error) {
270277
if err != nil {
271278
return
272279
}
273-
for _, col := range w.unknownNodes {
280+
for _, col := range w.UnknownNodes {
274281
log.Printf("unhandled column type %T", col)
275282
}
276283
return
277284
}
285+
286+
func AllColsContained(set ReferredCols, cols []string) bool {
287+
if cols == nil {
288+
if set == nil {
289+
return true
290+
} else {
291+
return false
292+
}
293+
}
294+
if len(set) != len(cols) {
295+
return false
296+
}
297+
for _, col := range cols {
298+
if _, exist := set[col]; !exist {
299+
return false
300+
}
301+
}
302+
return true
303+
}

pkg/walk/walker_test.go

+1-20
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,6 @@ func TestParser(t *testing.T) {
3434
})
3535
}
3636

37-
func allColsContained(set ReferredCols, cols []string) bool {
38-
if cols == nil {
39-
if set == nil {
40-
return true
41-
} else {
42-
return false
43-
}
44-
}
45-
if len(set) != len(cols) {
46-
return false
47-
}
48-
for _, col := range cols {
49-
if _, exist := set[col]; !exist {
50-
return false
51-
}
52-
}
53-
return true
54-
}
55-
5637
func TestReferredVarsInSelectStatement(t *testing.T) {
5738
testCases := []struct {
5839
sql string
@@ -191,7 +172,7 @@ func TestReferredVarsInSelectStatement(t *testing.T) {
191172
}
192173

193174
for _, tc := range testCases {
194-
t.Run(tc.sql, func(t *testing.T) {
175+
t.Run(tc.sql, func(t *testing.T) {
195176
referredCols, err := func() (ReferredCols, error) {
196177
return ColNamesInSelect(tc.sql)
197178
}()

0 commit comments

Comments
 (0)