Skip to content

Commit 267dde3

Browse files
committed
brain/sqlbrain: implement recollection
For #90.
1 parent 113b2e3 commit 267dde3

File tree

3 files changed

+285
-1
lines changed

3 files changed

+285
-1
lines changed

brain/sqlbrain/learn.go

+85-1
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@ package sqlbrain
22

33
import (
44
"context"
5+
_ "embed"
56
"fmt"
7+
"strconv"
68

79
"zombiezen.com/go/sqlite/sqlitex"
810

911
"github.com/zephyrtronium/robot/brain"
12+
"github.com/zephyrtronium/robot/userhash"
1013
)
1114

1215
// Learn records a set of tuples.
@@ -62,6 +65,87 @@ func prefix(b []byte, tup []string) []byte {
6265
return b
6366
}
6467

68+
// Recall fills out with messages read from the brain.
6569
func (br *Brain) Recall(ctx context.Context, tag string, page string, out []brain.Message) (n int, next string, err error) {
66-
panic("unimplemented")
70+
t, s, err := pageparams(page)
71+
if err != nil {
72+
return 0, "", err
73+
}
74+
75+
conn, err := br.db.Take(ctx)
76+
defer br.db.Put(conn)
77+
if err != nil {
78+
return 0, "", fmt.Errorf("couldn't get connection to recall: %w", err)
79+
}
80+
81+
st, err := conn.Prepare(recallQuery)
82+
if err != nil {
83+
return 0, "", fmt.Errorf("couldn't prepare recollection: %w", err)
84+
}
85+
st.SetText(":tag", tag)
86+
st.SetInt64(":startTime", t)
87+
st.SetText(":startID", s)
88+
st.SetInt64(":n", int64(len(out)))
89+
for i := range out {
90+
ok, err := st.Step()
91+
if err != nil {
92+
return 0, page, fmt.Errorf("couldn't step recollection: %w", err)
93+
}
94+
if !ok {
95+
out = out[:i]
96+
break
97+
}
98+
var u userhash.Hash
99+
s = st.ColumnText(0)
100+
t = st.ColumnInt64(1)
101+
st.ColumnBytes(2, u[:])
102+
out[i] = brain.Message{
103+
ID: s,
104+
Timestamp: t / 1e6, // convert ns to ms
105+
Sender: u,
106+
Text: st.ColumnText(3),
107+
}
108+
}
109+
110+
if err = st.Reset(); err != nil {
111+
// Return the error along with our normal results below.
112+
err = fmt.Errorf("resetting recollection statement failed: %w", err)
113+
}
114+
if len(out) == 0 {
115+
// No results. Recollection has ended.
116+
// This also happens if we were given zero elements to fill,
117+
// but that's the caller's problem.
118+
return 0, "", err
119+
}
120+
return len(out), topage(t, s), err
67121
}
122+
123+
func pageparams(page string) (int64, string, error) {
124+
if page == "" {
125+
return 0, "", nil
126+
}
127+
r, err := strconv.QuotedPrefix(page)
128+
if err != nil {
129+
return 0, "", fmt.Errorf("bad page %q", page)
130+
}
131+
l := page[len(r):]
132+
t, err := strconv.ParseInt(l, 10, 64)
133+
if err != nil {
134+
return 0, "", fmt.Errorf("bad page %q", page)
135+
}
136+
id, err := strconv.Unquote(r)
137+
if err != nil {
138+
return 0, "", fmt.Errorf("bad page %q", page)
139+
}
140+
return t, id, nil
141+
}
142+
143+
func topage(t int64, id string) string {
144+
b := make([]byte, 0, 64)
145+
b = strconv.AppendQuoteToASCII(b, id)
146+
b = strconv.AppendInt(b, t, 10)
147+
return string(b)
148+
}
149+
150+
//go:embed recall.sql
151+
var recallQuery string

brain/sqlbrain/learn_test.go

+174
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,180 @@ func TestLearn(t *testing.T) {
394394
}
395395
}
396396

397+
func TestRecall(t *testing.T) {
398+
cases := []struct {
399+
name string
400+
learn []learn
401+
want []brain.Message
402+
}{
403+
{
404+
name: "empty",
405+
learn: nil,
406+
want: nil,
407+
},
408+
{
409+
name: "single",
410+
learn: []learn{
411+
{
412+
tag: "kessoku",
413+
user: userhash.Hash{0: 1, userhash.Size - 1: 2},
414+
id: "1",
415+
t: 1,
416+
tups: []brain.Tuple{
417+
{Prefix: strings.Fields("kita nijika ryo bocchi"), Suffix: ""},
418+
{Prefix: strings.Fields("nijika ryo bocchi"), Suffix: "kita"},
419+
{Prefix: strings.Fields("ryo bocchi"), Suffix: "nijika"},
420+
{Prefix: strings.Fields("bocchi"), Suffix: "ryo"},
421+
{Prefix: nil, Suffix: "bocchi"},
422+
},
423+
},
424+
},
425+
want: []brain.Message{
426+
{
427+
ID: "1",
428+
Sender: userhash.Hash{0: 1, userhash.Size - 1: 2},
429+
Text: "bocchiryonijikakita",
430+
Timestamp: 1,
431+
},
432+
},
433+
},
434+
{
435+
name: "several",
436+
learn: []learn{
437+
{
438+
tag: "kessoku",
439+
user: userhash.Hash{0: 1, userhash.Size - 1: 2},
440+
id: "1",
441+
t: 1,
442+
tups: []brain.Tuple{
443+
{Prefix: []string{"bocchi"}, Suffix: ""},
444+
{Prefix: nil, Suffix: "bocchi"},
445+
},
446+
},
447+
{
448+
tag: "kessoku",
449+
user: userhash.Hash{0: 3, userhash.Size - 1: 4},
450+
id: "2",
451+
t: 1, // n.b. same timestamp
452+
tups: []brain.Tuple{
453+
{Prefix: []string{"ryo"}, Suffix: ""},
454+
{Prefix: nil, Suffix: "ryo"},
455+
},
456+
},
457+
{
458+
tag: "kessoku",
459+
user: userhash.Hash{0: 5, userhash.Size - 1: 6},
460+
id: "0", // n.b. lexicographically smaller id
461+
t: 3,
462+
tups: []brain.Tuple{
463+
{Prefix: []string{"nijika"}, Suffix: ""},
464+
{Prefix: nil, Suffix: "nijika"},
465+
},
466+
},
467+
{
468+
tag: "kessoku",
469+
user: userhash.Hash{0: 7, userhash.Size - 1: 8},
470+
id: "4",
471+
t: 4,
472+
tups: []brain.Tuple{
473+
{Prefix: []string{"kita"}, Suffix: ""},
474+
{Prefix: nil, Suffix: "kita"},
475+
},
476+
},
477+
},
478+
want: []brain.Message{
479+
{
480+
ID: "1",
481+
Sender: userhash.Hash{0: 1, userhash.Size - 1: 2},
482+
Text: "bocchi",
483+
Timestamp: 1,
484+
},
485+
{
486+
ID: "2",
487+
Sender: userhash.Hash{0: 3, userhash.Size - 1: 4},
488+
Text: "ryo",
489+
Timestamp: 1,
490+
},
491+
{
492+
ID: "0",
493+
Sender: userhash.Hash{0: 5, userhash.Size - 1: 6},
494+
Text: "nijika",
495+
Timestamp: 3,
496+
},
497+
{
498+
ID: "4",
499+
Sender: userhash.Hash{0: 7, userhash.Size - 1: 8},
500+
Text: "kita",
501+
Timestamp: 4,
502+
},
503+
},
504+
},
505+
{
506+
name: "tagged",
507+
learn: []learn{
508+
{
509+
tag: "sickhack",
510+
user: userhash.Hash{0: 1, userhash.Size - 1: 2},
511+
id: "1",
512+
t: 1,
513+
tups: []brain.Tuple{
514+
{Prefix: strings.Fields("kita nijika ryo bocchi"), Suffix: ""},
515+
{Prefix: strings.Fields("nijika ryo bocchi"), Suffix: "kita"},
516+
{Prefix: strings.Fields("ryo bocchi"), Suffix: "nijika"},
517+
{Prefix: strings.Fields("bocchi"), Suffix: "ryo"},
518+
{Prefix: nil, Suffix: "bocchi"},
519+
},
520+
},
521+
},
522+
want: nil,
523+
},
524+
}
525+
for _, c := range cases {
526+
t.Run(c.name, func(t *testing.T) {
527+
t.Parallel()
528+
ctx := context.Background()
529+
db := testDB(ctx)
530+
br, err := sqlbrain.Open(ctx, db)
531+
if err != nil {
532+
t.Fatalf("couldn't open brain: %v", err)
533+
}
534+
for _, m := range c.learn {
535+
msg := brain.Message{
536+
ID: m.id,
537+
Sender: m.user,
538+
Timestamp: m.t,
539+
}
540+
err := br.Learn(ctx, m.tag, &msg, m.tups)
541+
if err != nil {
542+
t.Errorf("failed to learn %v/%v: %v", m.tag, m.id, err)
543+
}
544+
}
545+
var page string
546+
got := make([]brain.Message, 1)
547+
for _, want := range c.want {
548+
n, next, err := br.Recall(ctx, "kessoku", page, got)
549+
if err != nil {
550+
t.Errorf("recall failed on page %s: %v", page, err)
551+
}
552+
if n != 1 {
553+
t.Errorf("wrong number of recalled messages for page %s: want 1, got %d", page, n)
554+
}
555+
if got[0] != want {
556+
t.Errorf("wrong result: want %+v, got %+v", want, got[0])
557+
}
558+
page = next
559+
}
560+
n, next, err := br.Recall(ctx, "kessoku", page, got)
561+
if err != nil {
562+
t.Errorf("final recall failed on page %s: %v", page, err)
563+
}
564+
if n != 0 || next != "" {
565+
t.Errorf("final recall gave wrong results: want n=0 with empty next, got n=%d next=%s", n, next)
566+
}
567+
})
568+
}
569+
}
570+
397571
func BenchmarkLearn(b *testing.B) {
398572
dir := filepath.ToSlash(b.TempDir())
399573
new := func(ctx context.Context, b *testing.B) brain.Learner {

brain/sqlbrain/recall.sql

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
WITH m AS (
2+
SELECT
3+
tag,
4+
id,
5+
time,
6+
user
7+
FROM messages
8+
WHERE tag = :tag AND (time > :startTime OR time = :startTime AND id > :startID) AND deleted IS NULL
9+
ORDER BY time, id
10+
LIMIT :n
11+
), k AS (
12+
SELECT
13+
m.id,
14+
m.time,
15+
m.user,
16+
knowledge.suffix
17+
FROM m JOIN knowledge ON m.tag = knowledge.tag AND m.id = knowledge.id
18+
ORDER BY LENGTH(knowledge.prefix)
19+
)
20+
SELECT
21+
id,
22+
time,
23+
user,
24+
GROUP_CONCAT(suffix, '') AS msg
25+
FROM k
26+
GROUP BY id, time, user

0 commit comments

Comments
 (0)