|
4 | 4 | package trie
|
5 | 5 |
|
6 | 6 | import (
|
7 |
| - "context" |
| 7 | + "cmp" |
| 8 | + "fmt" |
| 9 | + "iter" |
| 10 | + "slices" |
| 11 | + "sort" |
| 12 | + "strings" |
8 | 13 | "sync"
|
9 | 14 | )
|
10 | 15 |
|
11 |
| -// Trie is a simple trie that accepts arbitrary Key types as its input. |
12 |
| -type Trie struct { |
13 |
| - children map[interface{}]*Trie |
14 |
| - hasValue bool |
15 |
| - mu sync.RWMutex |
16 |
| - label Label |
17 |
| - value interface{} |
| 16 | +// Tokenizer is an object that tokenize a L into individual keys. |
| 17 | +// For example, a string tokenizer would split a string into individual runes. |
| 18 | +type Tokenizer[L any, K cmp.Ordered] interface { |
| 19 | + Tokenize(L) (iter.Seq[K], error) |
| 20 | +} |
| 21 | + |
| 22 | +// TokenizeFunc is a function that implements the Tokenizer interface |
| 23 | +type TokenizeFunc[L any, K cmp.Ordered] func(L) (iter.Seq[K], error) |
| 24 | + |
| 25 | +func (f TokenizeFunc[L, K]) Tokenize(in L) (iter.Seq[K], error) { |
| 26 | + return f(in) |
18 | 27 | }
|
19 | 28 |
|
20 |
| -// New creates a new Trie |
21 |
| -func New() *Trie { |
22 |
| - return newTrie(nil) |
| 29 | +// Trie is a trie that accepts arbitrary Key types as its input. |
| 30 | +// |
| 31 | +// L represents the "label", the input that is used to Get/Set/Delete |
| 32 | +// a value from the trie. |
| 33 | +// |
| 34 | +// K represents the "key", the individual components that are associated |
| 35 | +// with the nodes in the trie. |
| 36 | +// |
| 37 | +// V represents the "value", the data that is stored in the trie. |
| 38 | +// Data is stored at the leaf nodes of the trie. |
| 39 | +type Trie[L any, K cmp.Ordered, V any] struct { |
| 40 | + mu sync.RWMutex |
| 41 | + root *node[K, V] |
| 42 | + tokenizer Tokenizer[L, K] |
23 | 43 | }
|
24 | 44 |
|
25 |
| -func newTrie(l Label) *Trie { |
26 |
| - return &Trie{ |
27 |
| - label: l, |
28 |
| - children: make(map[interface{}]*Trie), |
| 45 | +// Node represents an individual node in the trie. |
| 46 | +type Node[K cmp.Ordered, V any] interface { |
| 47 | + Key() K |
| 48 | + Value() V |
| 49 | + Children() iter.Seq[Node[K, V]] |
| 50 | + AddChild(Node[K, V]) |
| 51 | +} |
| 52 | + |
| 53 | +// New creates a new Trie object. |
| 54 | +func New[L any, K cmp.Ordered, V any](tokenizer Tokenizer[L, K]) *Trie[L, K, V] { |
| 55 | + return &Trie[L, K, V]{ |
| 56 | + root: newNode[K, V](), |
| 57 | + tokenizer: tokenizer, |
29 | 58 | }
|
30 | 59 | }
|
31 | 60 |
|
32 | 61 | // Get returns the value associated with `key`. The second return value
|
33 |
| -// is true if the value exists, false otherwise |
34 |
| -func (t *Trie) Get(key Key) (interface{}, bool) { |
| 62 | +// indicates if the value was found. |
| 63 | +func (t *Trie[L, K, V]) Get(key L) (V, bool) { |
| 64 | + var zero V |
| 65 | + iter, err := t.tokenizer.Tokenize(key) |
| 66 | + if err != nil { |
| 67 | + return zero, false |
| 68 | + } |
| 69 | + |
35 | 70 | t.mu.RLock()
|
36 | 71 | defer t.mu.RUnlock()
|
37 |
| - node := t |
38 |
| - for iter := key.Labels(); iter.Next(); { |
39 |
| - l := iter.Label() |
40 |
| - node = node.children[l.UniqueID()] |
41 |
| - if node == nil { |
42 |
| - return nil, false |
| 72 | + var tokens []K |
| 73 | + for x := range iter { |
| 74 | + tokens = append(tokens, x) |
| 75 | + } |
| 76 | + return get(t.root, tokens) |
| 77 | +} |
| 78 | + |
| 79 | +func get[K cmp.Ordered, V any](root Node[K, V], tokens []K) (V, bool) { |
| 80 | + if len(tokens) > 0 { |
| 81 | + for child := range root.Children() { |
| 82 | + if child.Key() == tokens[0] { |
| 83 | + // found the current token in the children. |
| 84 | + if len(tokens) == 1 { |
| 85 | + // this is the node we're looking for |
| 86 | + return child.Value(), true |
| 87 | + } |
| 88 | + // we need to traverse down the trie |
| 89 | + return get[K, V](child, tokens[1:]) |
| 90 | + } |
43 | 91 | }
|
44 | 92 | }
|
45 |
| - return node.value, true |
| 93 | + |
| 94 | + // if we got here, that means we couldn't find a common ancestor |
| 95 | + var zero V |
| 96 | + return zero, false |
46 | 97 | }
|
47 | 98 |
|
48 |
| -// Put sets `key` to point to data `value`. The return value is true |
49 |
| -// if the value was set anew. If this was an update operation, the return |
50 |
| -// value would be false |
51 |
| -func (t *Trie) Put(key Key, value interface{}) bool { |
| 99 | +// Delete removes data associated with `key`. It returns true if the value |
| 100 | +// was found and deleted, false otherwise |
| 101 | +func (t *Trie[L, K, V]) Delete(key L) bool { |
| 102 | + iter, err := t.tokenizer.Tokenize(key) |
| 103 | + if err != nil { |
| 104 | + return false |
| 105 | + } |
| 106 | + var tokens []K |
| 107 | + for x := range iter { |
| 108 | + tokens = append(tokens, x) |
| 109 | + } |
| 110 | + |
52 | 111 | t.mu.Lock()
|
53 | 112 | defer t.mu.Unlock()
|
| 113 | + return delete[K, V](t.root, tokens) |
| 114 | +} |
| 115 | + |
| 116 | +func delete[K cmp.Ordered, V any](root *node[K, V], tokens []K) bool { |
| 117 | + if len(tokens) <= 0 { |
| 118 | + return false |
| 119 | + } |
| 120 | + |
| 121 | + for i, child := range root.children { |
| 122 | + if child.Key() == tokens[0] { |
| 123 | + if len(tokens) == 1 { |
| 124 | + // this is the node we're looking for |
| 125 | + root.children = slices.Delete(root.children, i, i+1) |
| 126 | + return true |
| 127 | + } |
54 | 128 |
|
55 |
| - node := t |
56 |
| - for iter := key.Labels(); iter.Next(); { |
57 |
| - l := iter.Label() |
58 |
| - child := node.children[l.UniqueID()] |
59 |
| - if child == nil { |
60 |
| - child = newTrie(l) |
61 |
| - node.children[l.UniqueID()] = child |
| 129 | + // we need to traverse down the trie |
| 130 | + if delete[K, V](child, tokens[1:]) { |
| 131 | + if len(child.children) == 0 { |
| 132 | + root.children = slices.Delete(root.children, i, i+1) |
| 133 | + } |
| 134 | + return true |
| 135 | + } |
| 136 | + return false |
62 | 137 | }
|
63 |
| - node = child |
64 | 138 | }
|
65 | 139 |
|
66 |
| - isNewVal := node.hasValue |
67 |
| - node.hasValue = true |
68 |
| - node.value = value |
69 |
| - return isNewVal |
| 140 | + return false |
70 | 141 | }
|
71 | 142 |
|
72 |
| -func (t *Trie) isLeaf() bool { |
73 |
| - return len(t.children) == 0 |
74 |
| -} |
| 143 | +// Put sets `key` to point to data `value`. |
| 144 | +func (t *Trie[L, K, V]) Put(key L, value V) error { |
| 145 | + iter, err := t.tokenizer.Tokenize(key) |
| 146 | + if err != nil { |
| 147 | + return fmt.Errorf(`failed to tokenize key: %w`, err) |
| 148 | + } |
| 149 | + node := t.root |
| 150 | + |
| 151 | + var tokens []K |
| 152 | + for x := range iter { |
| 153 | + tokens = append(tokens, x) |
| 154 | + } |
75 | 155 |
|
76 |
| -type ancestor struct { |
77 |
| - Label Label |
78 |
| - Node *Trie |
| 156 | + t.mu.Lock() |
| 157 | + defer t.mu.Unlock() |
| 158 | + put[K, V](node, tokens, value) |
| 159 | + return nil |
79 | 160 | }
|
80 | 161 |
|
81 |
| -// Delete removes data associated with `key`. It returns true if the value |
82 |
| -// was found and deleted, false otherwise |
83 |
| -func (t *Trie) Delete(key Key) bool { |
84 |
| - var ancestors []ancestor |
85 |
| - node := t |
86 |
| - for iter := key.Labels(); iter.Next(); { |
87 |
| - l := iter.Label() |
88 |
| - ancestors = append(ancestors, ancestor{ |
89 |
| - Label: l, |
90 |
| - Node: node, |
91 |
| - }) |
92 |
| - node = node.children[l.UniqueID()] |
93 |
| - if node == nil { |
94 |
| - // node does not exist |
95 |
| - return false |
| 162 | +func put[K cmp.Ordered, V any](root Node[K, V], tokens []K, value V) { |
| 163 | + if len(tokens) == 0 { |
| 164 | + return |
| 165 | + } |
| 166 | + |
| 167 | + for _, token := range tokens { |
| 168 | + for child := range root.Children() { |
| 169 | + if child.Key() == token { |
| 170 | + // found the current token in the children. |
| 171 | + // we need to traverse down the trie |
| 172 | + put[K, V](child, tokens[1:], value) |
| 173 | + return |
| 174 | + } |
96 | 175 | }
|
97 | 176 | }
|
98 | 177 |
|
99 |
| - // delete the node value |
100 |
| - node.value = nil |
| 178 | + // if we got here, that means we couldn't find a common ancestor |
101 | 179 |
|
102 |
| - // if leaf, remove it from its parent's children map. Repeat for ancestors. |
103 |
| - if !node.isLeaf() { |
104 |
| - return true |
| 180 | + // the first token has already been consumed, create a new node, |
| 181 | + var newRoot *node[K, V] |
| 182 | + var cur *node[K, V] |
| 183 | + for _, token := range tokens { // duplicate token? |
| 184 | + newNode := newNode[K, V]() |
| 185 | + newNode.key = token |
| 186 | + if cur == nil { |
| 187 | + newRoot = newNode |
| 188 | + } else { |
| 189 | + cur.children = append(cur.children, newNode) |
| 190 | + } |
| 191 | + cur = newNode |
105 | 192 | }
|
106 |
| - // iterate backwards over the ancestors |
107 |
| - for i := len(ancestors) - 1; i >= 0; i-- { |
108 |
| - ancestor := ancestors[i] |
109 |
| - parent := ancestor.Node |
110 |
| - delete(parent.children, ancestor.Label.UniqueID()) |
| 193 | + // cur holds the last element. |
| 194 | + cur.value = value |
111 | 195 |
|
112 |
| - if !parent.isLeaf() { |
113 |
| - // parent has other children, stop |
114 |
| - break |
115 |
| - } |
116 |
| - parent.children = nil |
117 |
| - if parent.hasValue { |
118 |
| - // parent has a value, stop |
119 |
| - break |
| 196 | + root.AddChild(newRoot) |
| 197 | +} |
| 198 | + |
| 199 | +type node[K cmp.Ordered, V any] struct { |
| 200 | + mu sync.RWMutex |
| 201 | + key K |
| 202 | + value V |
| 203 | + children []*node[K, V] |
| 204 | +} |
| 205 | + |
| 206 | +func newNode[K cmp.Ordered, V any]() *node[K, V] { |
| 207 | + return &node[K, V]{} |
| 208 | +} |
| 209 | + |
| 210 | +func (n *node[K, V]) Key() K { |
| 211 | + return n.key |
| 212 | +} |
| 213 | + |
| 214 | +func (n *node[K, V]) Value() V { |
| 215 | + return n.value |
| 216 | +} |
| 217 | + |
| 218 | +func (n *node[K, V]) Children() iter.Seq[Node[K, V]] { |
| 219 | + n.mu.RLock() |
| 220 | + children := make([]*node[K, V], len(n.children)) |
| 221 | + copy(children, n.children) |
| 222 | + n.mu.RUnlock() |
| 223 | + return func(yield func(Node[K, V]) bool) { |
| 224 | + for _, child := range children { |
| 225 | + if !yield(child) { |
| 226 | + break |
| 227 | + } |
120 | 228 | }
|
121 | 229 | }
|
122 |
| - return true |
123 | 230 | }
|
124 | 231 |
|
125 |
| -// WalkPair is what you get when you call `Walk()` on a trie. |
126 |
| -type WalkPair struct { |
127 |
| - // Because we have a generic "Label" type, we unfortunately cannot |
128 |
| - // provide a re-constructed Key object for the user to handle. |
129 |
| - // Instead we provide this value as a slice of Labels |
130 |
| - Labels []Label |
| 232 | +func (n *node[K, V]) AddChild(child Node[K, V]) { |
| 233 | + n.mu.Lock() |
| 234 | + // This is kind of gross, but we're only covering *node[T] with |
| 235 | + // Node[T] interface because we don't want the users to instantiate |
| 236 | + // their own nodes... so this type conversion is safe. |
| 237 | + //nolint:forcetypeassert |
| 238 | + n.children = append(n.children, child.(*node[K, V])) |
| 239 | + sort.Slice(n.children, func(i, j int) bool { |
| 240 | + return n.children[i].Key() < n.children[j].Key() |
| 241 | + }) |
| 242 | + n.mu.Unlock() |
| 243 | +} |
131 | 244 |
|
132 |
| - // Value is the value associated with the Labels |
133 |
| - Value interface{} |
| 245 | +type VisitMetadata struct { |
| 246 | + Depth int |
134 | 247 | }
|
135 | 248 |
|
136 |
| -// Walk returns a channel that you can read from to access all data |
137 |
| -// that is stored within this trie. |
138 |
| -func (t *Trie) Walk(ctx context.Context) <-chan WalkPair { |
139 |
| - ch := make(chan WalkPair) |
140 |
| - go t.walk(ctx, ch, nil) |
141 |
| - return ch |
| 249 | +type Visitor[K cmp.Ordered, V any] interface { |
| 250 | + Visit(Node[K, V], VisitMetadata) bool |
142 | 251 | }
|
143 | 252 |
|
144 |
| -func (t *Trie) walk(ctx context.Context, dst chan WalkPair, labels []Label) { |
145 |
| - if labels == nil { |
146 |
| - t.mu.RLock() |
147 |
| - defer t.mu.RUnlock() |
148 |
| - defer close(dst) |
149 |
| - } |
| 253 | +func Walk[L any, K cmp.Ordered, V any](trie *Trie[L, K, V], v Visitor[K, V]) { |
| 254 | + var meta VisitMetadata |
| 255 | + meta.Depth = 1 |
| 256 | + walk(trie.root, v, meta) |
| 257 | +} |
150 | 258 |
|
151 |
| - if t.hasValue { |
152 |
| - p := WalkPair{ |
153 |
| - Labels: labels, |
154 |
| - Value: t.value, |
155 |
| - } |
156 |
| - select { |
157 |
| - case <-ctx.Done(): |
158 |
| - return |
159 |
| - case dst <- p: |
| 259 | +func walk[K cmp.Ordered, V any](node Node[K, V], v Visitor[K, V], meta VisitMetadata) { |
| 260 | + for child := range node.Children() { |
| 261 | + if !v.Visit(child, meta) { |
| 262 | + break |
160 | 263 | }
|
| 264 | + walk(child, v, VisitMetadata{Depth: meta.Depth + 1}) |
161 | 265 | }
|
| 266 | +} |
162 | 267 |
|
163 |
| - for _, child := range t.children { |
164 |
| - child.walk(ctx, dst, append(labels, child.label)) |
| 268 | +type dumper[K cmp.Ordered, V any] struct{} |
| 269 | + |
| 270 | +func (dumper[K, V]) Visit(n Node[K, V], meta VisitMetadata) bool { |
| 271 | + var sb strings.Builder |
| 272 | + for i := 0; i < meta.Depth; i++ { |
| 273 | + sb.WriteString(" ") |
165 | 274 | }
|
| 275 | + |
| 276 | + fmt.Fprintf(&sb, "%v: %v", n.Key(), n.Value()) |
| 277 | + fmt.Println(sb.String()) |
| 278 | + return true |
| 279 | +} |
| 280 | + |
| 281 | +func Dumper[K cmp.Ordered, V any]() Visitor[K, V] { |
| 282 | + return dumper[K, V]{} |
166 | 283 | }
|
0 commit comments