@@ -6,40 +6,46 @@ import (
6
6
"slices"
7
7
8
8
"github.com/nais/api/internal/graph/model"
9
+ "github.com/sirupsen/logrus"
9
10
"github.com/sourcegraph/conc/pool"
10
11
)
11
12
12
13
// SortFunc compares two values of type V and returns an integer indicating their order.
13
14
// If a < b, the function should return a negative value.
14
15
// If a == b, the function should return 0.
15
16
// If a > b, the function should return a positive value.
16
- type SortFunc [V any ] func (ctx context.Context , a , b V ) int
17
+ type SortFunc [T any ] func (ctx context.Context , a , b T ) int
17
18
18
19
// ConcurrentSortFunc should return an integer indicating the order of the given value.
19
20
// The results will later be sorted by the returned value.
20
- type ConcurrentSortFunc [V any ] func (ctx context.Context , a V ) int
21
+ type ConcurrentSortFunc [T any ] func (ctx context.Context , a T ) int
21
22
22
23
// Filter is a function that returns true if the given value should be included in the result.
23
- type Filter [V any , FilterObj any ] func (ctx context.Context , v V , filter FilterObj ) bool
24
+ type Filter [T any , FilterObj any ] func (ctx context.Context , v T , filter FilterObj ) bool
25
+
26
+ // TieBreaker is a combination of a SortField and a direction that might be able to resolve equal fields during sorting.
27
+ // If the direction is not supplied, the direction used for the original sort will be used. The referenced field must be
28
+ // registered with RegisterSort (concurrent tie-break sorters are not supported).
29
+ type TieBreaker [SortField comparable ] struct {
30
+ Field SortField
31
+ Direction * model.OrderDirection
32
+ }
24
33
25
- type funcs [V any ] struct {
26
- concurrentSort ConcurrentSortFunc [V ]
27
- sort SortFunc [V ]
34
+ type funcs [T any , SortField comparable ] struct {
35
+ concurrentSort ConcurrentSortFunc [T ]
36
+ sort SortFunc [T ]
37
+ tieBreakers []TieBreaker [SortField ]
28
38
}
29
39
30
- type SortFilter [V any , SortField comparable , FilterObj comparable ] struct {
31
- sorters map [SortField ]funcs [V ]
32
- filters []Filter [V , FilterObj ]
33
- tieBreakSortField SortField
40
+ type SortFilter [T any , SortField comparable , FilterObj comparable ] struct {
41
+ sorters map [SortField ]funcs [T , SortField ]
42
+ filters []Filter [T , FilterObj ]
34
43
}
35
44
36
- // New creates a new SortFilter with the given tieBreakSortField.
37
- // The tieBreakSortField is used when two values are equal in the Sort function, and will use the direction supplied
38
- // when calling Sort. The tieBreakSortField must not be registered as a ConcurrentSort.
39
- func New [V any , SortField comparable , FilterObj comparable ](tieBreakSortField SortField ) * SortFilter [V , SortField , FilterObj ] {
40
- return & SortFilter [V , SortField , FilterObj ]{
41
- sorters : make (map [SortField ]funcs [V ]),
42
- tieBreakSortField : tieBreakSortField ,
45
+ // New creates a new SortFilter
46
+ func New [T any , SortField comparable , FilterObj comparable ]() * SortFilter [T , SortField , FilterObj ] {
47
+ return & SortFilter [T , SortField , FilterObj ]{
48
+ sorters : make (map [SortField ]funcs [T , SortField ]),
43
49
}
44
50
}
45
51
@@ -49,25 +55,29 @@ func (s *SortFilter[T, SortField, FilterObj]) SupportsSort(field SortField) bool
49
55
return exists
50
56
}
51
57
52
- func (s * SortFilter [T , SortField , FilterObj ]) RegisterSort (field SortField , sort SortFunc [T ]) {
58
+ // RegisterSort will add support for sorting on a specific field. Optional tie-breakers can be supplied to resolve equal
59
+ // values, and will be executed in the given order.
60
+ func (s * SortFilter [T , SortField , FilterObj ]) RegisterSort (field SortField , sort SortFunc [T ], tieBreakers ... TieBreaker [SortField ]) {
53
61
if _ , ok := s .sorters [field ]; ok {
54
62
panic (fmt .Sprintf ("sort field is already registered: %v" , field ))
55
63
}
56
64
57
- s .sorters [field ] = funcs [T ]{
58
- sort : sort ,
65
+ s .sorters [field ] = funcs [T , SortField ]{
66
+ sort : sort ,
67
+ tieBreakers : tieBreakers ,
59
68
}
60
69
}
61
70
62
- func (s * SortFilter [T , SortField , FilterObj ]) RegisterConcurrentSort (field SortField , sort ConcurrentSortFunc [T ]) {
71
+ // RegisterConcurrentSort will add support for doing concurrent sorting on a specific field. Optional tie-breakers can
72
+ // be supplied to resolve equal values, and will be executed in the given order.
73
+ func (s * SortFilter [T , SortField , FilterObj ]) RegisterConcurrentSort (field SortField , sort ConcurrentSortFunc [T ], tieBreakers ... TieBreaker [SortField ]) {
63
74
if _ , ok := s .sorters [field ]; ok {
64
75
panic (fmt .Sprintf ("sort field is already registered: %v" , field ))
65
- } else if field == s .tieBreakSortField {
66
- panic (fmt .Sprintf ("sort field is used for tie break and can not be concurrent: %v" , field ))
67
76
}
68
77
69
- s .sorters [field ] = funcs [T ]{
78
+ s .sorters [field ] = funcs [T , SortField ]{
70
79
concurrentSort : sort ,
80
+ tieBreakers : tieBreakers ,
71
81
}
72
82
}
73
83
@@ -129,14 +139,14 @@ func (s *SortFilter[T, SortField, FilterObj]) Sort(ctx context.Context, items []
129
139
}
130
140
131
141
if sorter .concurrentSort != nil {
132
- s .sortConcurrent (ctx , items , sorter .concurrentSort , direction )
142
+ s .sortConcurrent (ctx , items , sorter .concurrentSort , field , direction , sorter . tieBreakers ... )
133
143
return
134
144
}
135
145
136
- s .sort (ctx , items , sorter .sort , direction )
146
+ s .sort (ctx , items , sorter .sort , field , direction , sorter . tieBreakers ... )
137
147
}
138
148
139
- func (s * SortFilter [T , SortField , FilterObj ]) sortConcurrent (ctx context.Context , items []T , sort ConcurrentSortFunc [T ], direction model.OrderDirection ) {
149
+ func (s * SortFilter [T , SortField , FilterObj ]) sortConcurrent (ctx context.Context , items []T , sort ConcurrentSortFunc [T ], field SortField , direction model.OrderDirection , tieBreakers ... TieBreaker [ SortField ] ) {
140
150
type sortable struct {
141
151
item T
142
152
key int
@@ -160,7 +170,7 @@ func (s *SortFilter[T, SortField, FilterObj]) sortConcurrent(ctx context.Context
160
170
161
171
slices .SortStableFunc (res , func (a , b sortable ) int {
162
172
if b .key == a .key {
163
- return s .tieBreak (ctx , a .item , b .item , direction )
173
+ return s .tieBreak (ctx , a .item , b .item , field , direction , tieBreakers ... )
164
174
}
165
175
166
176
if direction == model .OrderDirectionDesc {
@@ -174,7 +184,7 @@ func (s *SortFilter[T, SortField, FilterObj]) sortConcurrent(ctx context.Context
174
184
}
175
185
}
176
186
177
- func (s * SortFilter [T , SortField , FilterObj ]) sort (ctx context.Context , items []T , sort SortFunc [T ], direction model.OrderDirection ) {
187
+ func (s * SortFilter [T , SortField , FilterObj ]) sort (ctx context.Context , items []T , sort SortFunc [T ], field SortField , direction model.OrderDirection , tieBreakers ... TieBreaker [ SortField ] ) {
178
188
slices .SortStableFunc (items , func (a , b T ) int {
179
189
var ret int
180
190
if direction == model .OrderDirectionDesc {
@@ -184,16 +194,54 @@ func (s *SortFilter[T, SortField, FilterObj]) sort(ctx context.Context, items []
184
194
}
185
195
186
196
if ret == 0 {
187
- return s .tieBreak (ctx , a , b , direction )
197
+ return s .tieBreak (ctx , a , b , field , direction , tieBreakers ... )
188
198
}
189
199
return ret
190
200
})
191
201
}
192
202
193
- func (s * SortFilter [T , SortField , FilterObj ]) tieBreak (ctx context.Context , a , b T , direction model.OrderDirection ) int {
194
- if direction == model .OrderDirectionDesc {
195
- return s .sorters [s .tieBreakSortField ].sort (ctx , b , a )
203
+ // tieBreak will resolve equal fields after the initial sort by using the supplied tie-breakers. The function will
204
+ // return as soon as a tie-breaker returns a non-zero value.
205
+ func (s * SortFilter [T , SortField , FilterObj ]) tieBreak (ctx context.Context , a , b T , field SortField , direction model.OrderDirection , tieBreakers ... TieBreaker [SortField ]) int {
206
+ for _ , tb := range tieBreakers {
207
+ dir := direction
208
+ if tb .Direction != nil {
209
+ dir = * tb .Direction
210
+ }
211
+
212
+ sorter , ok := s .sorters [tb .Field ]
213
+ if ! ok {
214
+ logrus .WithFields (logrus.Fields {
215
+ "field_type" : fmt .Sprintf ("%T" , field ),
216
+ "tie_breaker" : tb .Field ,
217
+ }).Errorf ("no sort registered for tie-breaker" )
218
+ continue
219
+ } else if sorter .sort == nil {
220
+ logrus .WithFields (logrus.Fields {
221
+ "field_type" : fmt .Sprintf ("%T" , field ),
222
+ "tie_breaker" : tb .Field ,
223
+ }).Errorf ("tie-breaker can not be a concurrent sort" )
224
+ continue
225
+ }
226
+
227
+ var v int
228
+ if dir == model .OrderDirectionDesc {
229
+ v = sorter .sort (ctx , b , a )
230
+ } else {
231
+ v = sorter .sort (ctx , a , b )
232
+ }
233
+
234
+ if v != 0 {
235
+ return v
236
+ }
196
237
}
197
238
198
- return s .sorters [s .tieBreakSortField ].sort (ctx , a , b )
239
+ logrus .
240
+ WithFields (logrus.Fields {
241
+ "field_type" : fmt .Sprintf ("%T" , field ),
242
+ "sort_field" : field ,
243
+ "tie_breakers" : tieBreakers ,
244
+ }).
245
+ Errorf ("unable to tie-break sort, gotta have more tie-breakers" )
246
+ return 0
199
247
}
0 commit comments