@@ -3,33 +3,32 @@ package walk
3
3
import (
4
4
"fmt"
5
5
"log"
6
- "sort"
7
6
"strings"
8
7
9
8
"github.com/auxten/postgresql-parser/pkg/sql/parser"
10
9
"github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
10
+ "github.com/auxten/postgresql-parser/pkg/util/set"
11
11
)
12
12
13
13
type AstWalker struct {
14
- unknownNodes []interface {}
14
+ UnknownNodes []interface {}
15
15
Fn func (ctx interface {}, node interface {}) (stop bool )
16
16
}
17
17
type ReferredCols map [string ]int
18
18
19
19
func (rc ReferredCols ) ToList () []string {
20
20
cols := make ([]string , len (rc ))
21
21
i := 0
22
- for k , _ := range rc {
22
+ for k := range rc {
23
23
cols [i ] = k
24
24
i ++
25
25
}
26
- sort .Strings (cols )
27
- return cols
26
+ return set .SortDeDup (cols )
28
27
}
29
28
30
29
func (w * AstWalker ) Walk (stmts parser.Statements , ctx interface {}) (ok bool , err error ) {
31
30
32
- w .unknownNodes = make ([]interface {}, 0 )
31
+ w .UnknownNodes = make ([]interface {}, 0 )
33
32
asts := make ([]tree.NodeFormatter , len (stmts ))
34
33
for si , stmt := range stmts {
35
34
asts [si ] = stmt .AST
@@ -69,8 +68,8 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
69
68
walk (node .Left , node .Right )
70
69
case * tree.CaseExpr :
71
70
walk (node .Expr , node .Else )
72
- for _ , w := range node .Whens {
73
- walk (w .Cond , w .Val )
71
+ for _ , when := range node .Whens {
72
+ walk (when .Cond , when .Val )
74
73
}
75
74
case * tree.RangeCond :
76
75
walk (node .Left , node .From , node .To )
@@ -98,6 +97,11 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
98
97
walk (expr )
99
98
}
100
99
case * tree.FamilyTableDef :
100
+ case * tree.From :
101
+ walk (node .AsOf )
102
+ for _ , table := range node .Tables {
103
+ walk (table )
104
+ }
101
105
case * tree.FuncExpr :
102
106
if node .WindowDef != nil {
103
107
walk (node .WindowDef )
@@ -111,6 +115,12 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
111
115
case * tree.NumVal :
112
116
case * tree.OnJoinCond :
113
117
walk (node .Expr )
118
+ case * tree.Order :
119
+ walk (node .Expr , node .Table )
120
+ case tree.OrderBy :
121
+ for _ , order := range node {
122
+ walk (order )
123
+ }
114
124
case * tree.OrExpr :
115
125
walk (node .Left , node .Right )
116
126
case * tree.ParenExpr :
@@ -126,16 +136,12 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
126
136
walk (node .With )
127
137
}
128
138
if node .OrderBy != nil {
129
- for _ , order := range node .OrderBy {
130
- walk (order )
131
- }
139
+ walk (node .OrderBy )
132
140
}
133
141
if node .Limit != nil {
134
142
walk (node .Limit )
135
143
}
136
144
walk (node .Select )
137
- case * tree.Order :
138
- walk (node .Expr , node .Table )
139
145
case * tree.Limit :
140
146
walk (node .Count )
141
147
case * tree.SelectClause :
@@ -156,10 +162,7 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
156
162
walk (group )
157
163
}
158
164
}
159
- walk (node .From .AsOf )
160
- for _ , table := range node .From .Tables {
161
- walk (table )
162
- }
165
+ walk (& node .From )
163
166
case tree.SelectExpr :
164
167
walk (node .Expr )
165
168
case tree.SelectExprs :
@@ -173,6 +176,10 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
173
176
case * tree.StrVal :
174
177
case * tree.Subquery :
175
178
walk (node .Select )
179
+ case tree.TableExprs :
180
+ for _ , expr := range node {
181
+ walk (expr )
182
+ }
176
183
case * tree.TableName , tree.TableName :
177
184
case * tree.Tuple :
178
185
for _ , expr := range node .Exprs {
@@ -214,8 +221,8 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
214
221
walk (expr )
215
222
}
216
223
default :
217
- if w .unknownNodes != nil {
218
- w .unknownNodes = append (w .unknownNodes , node )
224
+ if w .UnknownNodes != nil {
225
+ w .UnknownNodes = append (w .UnknownNodes , node )
219
226
}
220
227
}
221
228
}
@@ -270,8 +277,27 @@ func ColNamesInSelect(sql string) (referredCols ReferredCols, err error) {
270
277
if err != nil {
271
278
return
272
279
}
273
- for _ , col := range w .unknownNodes {
280
+ for _ , col := range w .UnknownNodes {
274
281
log .Printf ("unhandled column type %T" , col )
275
282
}
276
283
return
277
284
}
285
+
286
+ func AllColsContained (set ReferredCols , cols []string ) bool {
287
+ if cols == nil {
288
+ if set == nil {
289
+ return true
290
+ } else {
291
+ return false
292
+ }
293
+ }
294
+ if len (set ) != len (cols ) {
295
+ return false
296
+ }
297
+ for _ , col := range cols {
298
+ if _ , exist := set [col ]; ! exist {
299
+ return false
300
+ }
301
+ }
302
+ return true
303
+ }
0 commit comments