Skip to content

Commit 16b5a7b

Browse files
author
Daisuke Maki
committed Oct 11, 2024
Initial v2 code
1 parent fe02a1e commit 16b5a7b

9 files changed

+292
-276
lines changed
 

‎README.md

-29
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,3 @@ This trie is implemented such that generic Key types can be used.
44
Most other trie implementations are optimized for string based keys, but my use
55
case is to match certain numeric opcodes to arbitrary data.
66

7-
Within this library Keys are treated as sequence of Labels.
8-
For example, a string can be thought of as Key that is comprised of a sequence
9-
of rune Labels.
10-
11-
Each Key needs to be able to break down into Labels via the `Iterate` method.
12-
Each Label in turn becomes the local key in a trie node.
13-
Each Label needs to implement a `UniqueID` method to identify itself.
14-
15-
# SYNOPSIS
16-
17-
```go
18-
ctx, cancel := context.WithCancel(context.Background())
19-
defer cancel()
20-
21-
t := trie.New()
22-
23-
t.Put(ctx, trie.StringKey("foo"), 1)
24-
v, ok := t.Get(ctx, trie.StringKey("foo"))
25-
ok := t.Delete(ctx, trie.StringKey("foo"))
26-
for p := range t.Walk(ctx) {
27-
// p.Labels
28-
// p.Value
29-
}
30-
```
31-
32-
# REFERENCES
33-
34-
Originally based on https://github.com/koron/go-trie
35-
Much code stolen from https://github.com/dghubble/trie

‎go.mod

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1-
module github.com/lestrrat-go/trie
1+
module github.com/lestrrat-go/trie/v2
22

3-
go 1.16
3+
go 1.23
44

5-
require github.com/stretchr/testify v1.7.0
5+
require github.com/stretchr/testify v1.9.0
6+
7+
require (
8+
github.com/davecgh/go-spew v1.1.1 // indirect
9+
github.com/pmezard/go-difflib v1.0.0 // indirect
10+
gopkg.in/yaml.v3 v3.0.1 // indirect
11+
)

‎go.sum

+6-7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
2-
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
1+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
2+
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
33
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
44
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
5-
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
6-
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
7-
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
5+
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
6+
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
87
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
98
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
10-
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
11-
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
9+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
10+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

‎interface.go

-27
This file was deleted.

‎string.go

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package trie
2+
3+
import (
4+
"iter"
5+
)
6+
7+
// String returns a Tokenizer that tokenizes a string into individual runes.
8+
func String() Tokenizer[string, rune] {
9+
return TokenizeFunc[string, rune](func(s string) (iter.Seq[rune], error) {
10+
return func(yield func(rune) bool) {
11+
for _, r := range s {
12+
if !yield(r) {
13+
break
14+
}
15+
}
16+
}, nil
17+
})
18+
}

‎string_key.go

-34
This file was deleted.

‎trie.go

+228-111
Original file line numberDiff line numberDiff line change
@@ -4,163 +4,280 @@
44
package trie
55

66
import (
7-
"context"
7+
"cmp"
8+
"fmt"
9+
"iter"
10+
"slices"
11+
"sort"
12+
"strings"
813
"sync"
914
)
1015

11-
// Trie is a simple trie that accepts arbitrary Key types as its input.
12-
type Trie struct {
13-
children map[interface{}]*Trie
14-
hasValue bool
15-
mu sync.RWMutex
16-
label Label
17-
value interface{}
16+
// Tokenizer is an object that tokenizes an L into individual keys.
17+
// For example, a string tokenizer would split a string into individual runes.
18+
type Tokenizer[L any, K cmp.Ordered] interface {
19+
Tokenize(L) (iter.Seq[K], error)
20+
}
21+
22+
// TokenizeFunc is a function that implements the Tokenizer interface
23+
type TokenizeFunc[L any, K cmp.Ordered] func(L) (iter.Seq[K], error)
24+
25+
func (f TokenizeFunc[L, K]) Tokenize(in L) (iter.Seq[K], error) {
26+
return f(in)
1827
}
1928

20-
// New creates a new Trie
21-
func New() *Trie {
22-
return newTrie(nil)
29+
// Trie is a trie that accepts arbitrary Key types as its input.
30+
//
31+
// L represents the "label", the input that is used to Get/Set/Delete
32+
// a value from the trie.
33+
//
34+
// K represents the "key", the individual components that are associated
35+
// with the nodes in the trie.
36+
//
37+
// V represents the "value", the data that is stored in the trie.
38+
// Data is stored at the node matching the last key of a label.
39+
type Trie[L any, K cmp.Ordered, V any] struct {
40+
mu sync.RWMutex
41+
root *node[K, V]
42+
tokenizer Tokenizer[L, K]
2343
}
2444

25-
func newTrie(l Label) *Trie {
26-
return &Trie{
27-
label: l,
28-
children: make(map[interface{}]*Trie),
45+
// Node represents an individual node in the trie.
46+
type Node[K cmp.Ordered, V any] interface {
47+
Key() K
48+
Value() V
49+
Children() iter.Seq[Node[K, V]]
50+
AddChild(Node[K, V])
51+
}
52+
53+
// New creates a new Trie object.
54+
func New[L any, K cmp.Ordered, V any](tokenizer Tokenizer[L, K]) *Trie[L, K, V] {
55+
return &Trie[L, K, V]{
56+
root: newNode[K, V](),
57+
tokenizer: tokenizer,
2958
}
3059
}
3160

3261
// Get returns the value associated with `key`. The second return value
33-
// is true if the value exists, false otherwise
34-
func (t *Trie) Get(key Key) (interface{}, bool) {
62+
// indicates if the value was found.
63+
func (t *Trie[L, K, V]) Get(key L) (V, bool) {
64+
var zero V
65+
iter, err := t.tokenizer.Tokenize(key)
66+
if err != nil {
67+
return zero, false
68+
}
69+
3570
t.mu.RLock()
3671
defer t.mu.RUnlock()
37-
node := t
38-
for iter := key.Labels(); iter.Next(); {
39-
l := iter.Label()
40-
node = node.children[l.UniqueID()]
41-
if node == nil {
42-
return nil, false
72+
var tokens []K
73+
for x := range iter {
74+
tokens = append(tokens, x)
75+
}
76+
return get(t.root, tokens)
77+
}
78+
79+
func get[K cmp.Ordered, V any](root Node[K, V], tokens []K) (V, bool) {
80+
if len(tokens) > 0 {
81+
for child := range root.Children() {
82+
if child.Key() == tokens[0] {
83+
// found the current token in the children.
84+
if len(tokens) == 1 {
85+
// this is the node we're looking for
86+
return child.Value(), true
87+
}
88+
// we need to traverse down the trie
89+
return get[K, V](child, tokens[1:])
90+
}
4391
}
4492
}
45-
return node.value, true
93+
94+
// if we got here, that means we couldn't find a common ancestor
95+
var zero V
96+
return zero, false
4697
}
4798

48-
// Put sets `key` to point to data `value`. The return value is true
49-
// if the value was set anew. If this was an update operation, the return
50-
// value would be false
51-
func (t *Trie) Put(key Key, value interface{}) bool {
99+
// Delete removes data associated with `key`. It returns true if the value
100+
// was found and deleted, false otherwise
101+
func (t *Trie[L, K, V]) Delete(key L) bool {
102+
iter, err := t.tokenizer.Tokenize(key)
103+
if err != nil {
104+
return false
105+
}
106+
var tokens []K
107+
for x := range iter {
108+
tokens = append(tokens, x)
109+
}
110+
52111
t.mu.Lock()
53112
defer t.mu.Unlock()
113+
return delete[K, V](t.root, tokens)
114+
}
115+
116+
func delete[K cmp.Ordered, V any](root *node[K, V], tokens []K) bool {
117+
if len(tokens) <= 0 {
118+
return false
119+
}
120+
121+
for i, child := range root.children {
122+
if child.Key() == tokens[0] {
123+
if len(tokens) == 1 {
124+
// this is the node we're looking for
125+
root.children = slices.Delete(root.children, i, i+1)
126+
return true
127+
}
54128

55-
node := t
56-
for iter := key.Labels(); iter.Next(); {
57-
l := iter.Label()
58-
child := node.children[l.UniqueID()]
59-
if child == nil {
60-
child = newTrie(l)
61-
node.children[l.UniqueID()] = child
129+
// we need to traverse down the trie
130+
if delete[K, V](child, tokens[1:]) {
131+
if len(child.children) == 0 {
132+
root.children = slices.Delete(root.children, i, i+1)
133+
}
134+
return true
135+
}
136+
return false
62137
}
63-
node = child
64138
}
65139

66-
isNewVal := node.hasValue
67-
node.hasValue = true
68-
node.value = value
69-
return isNewVal
140+
return false
70141
}
71142

72-
func (t *Trie) isLeaf() bool {
73-
return len(t.children) == 0
74-
}
143+
// Put sets `key` to point to data `value`.
144+
func (t *Trie[L, K, V]) Put(key L, value V) error {
145+
iter, err := t.tokenizer.Tokenize(key)
146+
if err != nil {
147+
return fmt.Errorf(`failed to tokenize key: %w`, err)
148+
}
149+
node := t.root
150+
151+
var tokens []K
152+
for x := range iter {
153+
tokens = append(tokens, x)
154+
}
75155

76-
type ancestor struct {
77-
Label Label
78-
Node *Trie
156+
t.mu.Lock()
157+
defer t.mu.Unlock()
158+
put[K, V](node, tokens, value)
159+
return nil
79160
}
80161

81-
// Delete removes data associated with `key`. It returns true if the value
82-
// was found and deleted, false otherwise
83-
func (t *Trie) Delete(key Key) bool {
84-
var ancestors []ancestor
85-
node := t
86-
for iter := key.Labels(); iter.Next(); {
87-
l := iter.Label()
88-
ancestors = append(ancestors, ancestor{
89-
Label: l,
90-
Node: node,
91-
})
92-
node = node.children[l.UniqueID()]
93-
if node == nil {
94-
// node does not exist
95-
return false
162+
func put[K cmp.Ordered, V any](root Node[K, V], tokens []K, value V) {
163+
if len(tokens) == 0 {
164+
return
165+
}
166+
167+
for _, token := range tokens {
168+
for child := range root.Children() {
169+
if child.Key() == token {
170+
// found the current token in the children.
171+
// we need to traverse down the trie
172+
put[K, V](child, tokens[1:], value)
173+
return
174+
}
96175
}
97176
}
98177

99-
// delete the node value
100-
node.value = nil
178+
// if we got here, that means we couldn't find a common ancestor
101179

102-
// if leaf, remove it from its parent's children map. Repeat for ancestors.
103-
if !node.isLeaf() {
104-
return true
180+
// create a chain of new nodes, one per token, to be attached under root.
181+
var newRoot *node[K, V]
182+
var cur *node[K, V]
183+
for _, token := range tokens { // duplicate token?
184+
newNode := newNode[K, V]()
185+
newNode.key = token
186+
if cur == nil {
187+
newRoot = newNode
188+
} else {
189+
cur.children = append(cur.children, newNode)
190+
}
191+
cur = newNode
105192
}
106-
// iterate backwards over the ancestors
107-
for i := len(ancestors) - 1; i >= 0; i-- {
108-
ancestor := ancestors[i]
109-
parent := ancestor.Node
110-
delete(parent.children, ancestor.Label.UniqueID())
193+
// cur holds the last element.
194+
cur.value = value
111195

112-
if !parent.isLeaf() {
113-
// parent has other children, stop
114-
break
115-
}
116-
parent.children = nil
117-
if parent.hasValue {
118-
// parent has a value, stop
119-
break
196+
root.AddChild(newRoot)
197+
}
198+
199+
type node[K cmp.Ordered, V any] struct {
200+
mu sync.RWMutex
201+
key K
202+
value V
203+
children []*node[K, V]
204+
}
205+
206+
func newNode[K cmp.Ordered, V any]() *node[K, V] {
207+
return &node[K, V]{}
208+
}
209+
210+
func (n *node[K, V]) Key() K {
211+
return n.key
212+
}
213+
214+
func (n *node[K, V]) Value() V {
215+
return n.value
216+
}
217+
218+
func (n *node[K, V]) Children() iter.Seq[Node[K, V]] {
219+
n.mu.RLock()
220+
children := make([]*node[K, V], len(n.children))
221+
copy(children, n.children)
222+
n.mu.RUnlock()
223+
return func(yield func(Node[K, V]) bool) {
224+
for _, child := range children {
225+
if !yield(child) {
226+
break
227+
}
120228
}
121229
}
122-
return true
123230
}
124231

125-
// WalkPair is what you get when you call `Walk()` on a trie.
126-
type WalkPair struct {
127-
// Because we have a generic "Label" type, we unfortunately cannot
128-
// provide a re-constructed Key object for the user to handle.
129-
// Instead we provide this value as a slice of Labels
130-
Labels []Label
232+
func (n *node[K, V]) AddChild(child Node[K, V]) {
233+
n.mu.Lock()
234+
// This is kind of gross, but we're only covering *node[K, V] with
235+
// Node[K, V] interface because we don't want the users to instantiate
236+
// their own nodes... so this type conversion is safe.
237+
//nolint:forcetypeassert
238+
n.children = append(n.children, child.(*node[K, V]))
239+
sort.Slice(n.children, func(i, j int) bool {
240+
return n.children[i].Key() < n.children[j].Key()
241+
})
242+
n.mu.Unlock()
243+
}
131244

132-
// Value is the value associated with the Labels
133-
Value interface{}
245+
type VisitMetadata struct {
246+
Depth int
134247
}
135248

136-
// Walk returns a channel that you can read from to access all data
137-
// that is stored within this trie.
138-
func (t *Trie) Walk(ctx context.Context) <-chan WalkPair {
139-
ch := make(chan WalkPair)
140-
go t.walk(ctx, ch, nil)
141-
return ch
249+
type Visitor[K cmp.Ordered, V any] interface {
250+
Visit(Node[K, V], VisitMetadata) bool
142251
}
143252

144-
func (t *Trie) walk(ctx context.Context, dst chan WalkPair, labels []Label) {
145-
if labels == nil {
146-
t.mu.RLock()
147-
defer t.mu.RUnlock()
148-
defer close(dst)
149-
}
253+
func Walk[L any, K cmp.Ordered, V any](trie *Trie[L, K, V], v Visitor[K, V]) {
254+
var meta VisitMetadata
255+
meta.Depth = 1
256+
walk(trie.root, v, meta)
257+
}
150258

151-
if t.hasValue {
152-
p := WalkPair{
153-
Labels: labels,
154-
Value: t.value,
155-
}
156-
select {
157-
case <-ctx.Done():
158-
return
159-
case dst <- p:
259+
func walk[K cmp.Ordered, V any](node Node[K, V], v Visitor[K, V], meta VisitMetadata) {
260+
for child := range node.Children() {
261+
if !v.Visit(child, meta) {
262+
break
160263
}
264+
walk(child, v, VisitMetadata{Depth: meta.Depth + 1})
161265
}
266+
}
162267

163-
for _, child := range t.children {
164-
child.walk(ctx, dst, append(labels, child.label))
268+
type dumper[K cmp.Ordered, V any] struct{}
269+
270+
func (dumper[K, V]) Visit(n Node[K, V], meta VisitMetadata) bool {
271+
var sb strings.Builder
272+
for i := 0; i < meta.Depth; i++ {
273+
sb.WriteString(" ")
165274
}
275+
276+
fmt.Fprintf(&sb, "%v: %v", n.Key(), n.Value())
277+
fmt.Println(sb.String())
278+
return true
279+
}
280+
281+
func Dumper[K cmp.Ordered, V any]() Visitor[K, V] {
282+
return dumper[K, V]{}
166283
}

‎trie_example_test.go

+10-19
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,41 @@
11
package trie_test
22

33
import (
4-
"context"
54
"fmt"
65

7-
"github.com/lestrrat-go/trie"
6+
"github.com/lestrrat-go/trie/v2"
87
)
98

10-
func ExampleStringKey() {
11-
tree := trie.New()
9+
func Example() {
10+
tree := trie.New[string, rune, any](trie.String())
1211

1312
// Put values in the trie
14-
tree.Put(trie.StringKey("foo"), "one")
15-
tree.Put(trie.StringKey("bar"), 2)
16-
tree.Put(trie.StringKey("baz"), 3.0)
17-
tree.Put(trie.StringKey("日本語"), []byte{'f', 'o', 'u', 'r'})
13+
tree.Put("foo", "one")
14+
tree.Put("bar", 2)
15+
tree.Put("baz", 3.0)
16+
tree.Put("日本語", []byte{'f', 'o', 'u', 'r'})
1817

1918
// Get a value from the trie
20-
v, ok := tree.Get(trie.StringKey("日本語"))
19+
v, ok := tree.Get("日本語")
2120
if !ok {
2221
fmt.Printf("failed to find key '日本語'\n")
2322
return
2423
}
2524
_ = v
2625

2726
// Delete a key from the trie
28-
if !tree.Delete(trie.StringKey("日本語")) {
27+
if !tree.Delete("日本語") {
2928
fmt.Printf("failed to delete key '日本語'\n")
3029
return
3130
}
3231

3332
// This time Get() should fail
34-
v, ok = tree.Get(trie.StringKey("日本語"))
33+
v, ok = tree.Get("日本語")
3534
if ok {
3635
fmt.Printf("key '日本語' should not exist\n")
3736
return
3837
}
3938
_ = v
4039

41-
ctx := context.Background()
42-
43-
// Or, walk the entire trie
44-
for p := range tree.Walk(ctx) {
45-
// Do something with the values...
46-
_ = p
47-
}
48-
4940
// OUTPUT:
5041
}

‎trie_test.go

+21-46
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,41 @@
11
package trie_test
22

33
import (
4-
"context"
5-
"fmt"
64
"testing"
75

8-
"github.com/lestrrat-go/trie"
9-
"github.com/stretchr/testify/assert"
6+
"github.com/lestrrat-go/trie/v2"
7+
"github.com/stretchr/testify/require"
108
)
119

1210
func TestTrie(t *testing.T) {
1311
t.Parallel()
1412

15-
tree := trie.New()
16-
tree.Put(trie.StringKey("foo"), 1)
17-
tree.Put(trie.StringKey("bar"), 2)
18-
tree.Put(trie.StringKey("baz"), 3)
19-
tree.Put(trie.StringKey("日本語"), 4)
20-
21-
ctx, cancel := context.WithCancel(context.Background())
22-
defer cancel()
23-
for p := range tree.Walk(ctx) {
24-
t.Logf("%#v", p)
25-
}
13+
tree := trie.New[string, rune, int](trie.String())
2614

2715
testcases := []struct {
28-
Key trie.Key
29-
Expected interface{}
30-
Missing bool
16+
Key string
17+
Value int
3118
}{
32-
{
33-
Key: trie.StringKey("foo"),
34-
Expected: 1,
35-
},
36-
{
37-
Key: trie.StringKey("日本語"),
38-
Expected: 4,
39-
},
40-
{
41-
Key: trie.StringKey("hoge"),
42-
Missing: true,
43-
},
19+
{"foo", 1},
20+
{"far", 2},
21+
{"for", 3},
22+
{"bar", 4},
23+
{"baz", 5},
4424
}
4525

4626
for _, tc := range testcases {
47-
tc := tc
48-
t.Run(fmt.Sprintf("%s", tc.Key), func(t *testing.T) {
49-
t.Parallel()
50-
v, ok := tree.Get(tc.Key)
51-
if tc.Missing {
52-
if !assert.False(t, ok, `tree.Get should return false`) {
53-
return
54-
}
55-
} else {
56-
if !assert.True(t, ok, `tree.Get should return true`) {
57-
return
58-
}
27+
tree.Put(tc.Key, tc.Value)
28+
}
5929

60-
if !assert.Equal(t, tc.Expected, v, `tree.Get should return expected value`) {
61-
return
62-
}
63-
}
30+
for _, tc := range testcases {
31+
t.Run(tc.Key, func(t *testing.T) {
32+
v, ok := tree.Get(tc.Key)
33+
require.True(t, ok, `tree.Get should return true`)
34+
require.Equal(t, tc.Value, v, `tree.Get should return expected value`)
6435
})
6536
}
37+
38+
require.True(t, tree.Delete("foo"), `tree.Delete should return true`)
39+
_, ok := tree.Get("foo")
40+
require.False(t, ok, `tree.Get should return false`)
6641
}

0 commit comments

Comments
 (0)
Please sign in to comment.