Skip to content

Commit 1c26481

Browse files
committed
brain/*: learn entire messages
Now brain implementations will have no excuse not to record message text in addition to tuples. For #90.
1 parent f6dfe4e commit 1c26481

14 files changed

+101
-63
lines changed

brain/brain.go

+8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
package brain
22

3+
import (
4+
"github.com/zephyrtronium/robot/message"
5+
"github.com/zephyrtronium/robot/userhash"
6+
)
7+
38
// Brain is a combined [Learner] and [Speaker].
49
type Brain interface {
510
Learner
611
Speaker
712
}
13+
14+
// Message is the message type used by a [Brain].
15+
type Message = message.Received[userhash.Hash]

brain/braintest/bench.go

+10-6
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"strconv"
1010
"strings"
1111
"testing"
12-
"time"
1312

1413
"github.com/zephyrtronium/robot/brain"
1514
"github.com/zephyrtronium/robot/userhash"
@@ -45,7 +44,8 @@ func BenchLearn(ctx context.Context, b *testing.B, new func(ctx context.Context,
4544
toks[len(toks)-1] = strconv.FormatInt(t, 10)
4645
id := randid()
4746
u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{}))))
48-
err := brain.Learn(ctx, l, "bocchi", id, u, time.Unix(t, 0), strings.Join(toks, " "))
47+
msg := brain.Message{ID: id, Sender: u, Timestamp: t * 1e3, Text: strings.Join(toks, " ")}
48+
err := brain.Learn(ctx, l, "bocchi", &msg)
4949
if err != nil {
5050
b.Errorf("error while learning: %v", err)
5151
}
@@ -83,7 +83,8 @@ func BenchLearn(ctx context.Context, b *testing.B, new func(ctx context.Context,
8383
rand.Shuffle(len(toks), func(i, j int) { toks[i], toks[j] = toks[j], toks[i] })
8484
id := randid()
8585
u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{}))))
86-
err := brain.Learn(ctx, l, "bocchi", id, u, time.Unix(t, 0), strings.Join(toks[:8], " "))
86+
msg := brain.Message{ID: id, Sender: u, Timestamp: t * 1e3, Text: strings.Join(toks, " ")}
87+
err := brain.Learn(ctx, l, "bocchi", &msg)
8788
if err != nil {
8889
b.Errorf("error while learning: %v", err)
8990
}
@@ -117,7 +118,8 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context,
117118
toks[len(toks)-1] = strconv.FormatInt(t, 10)
118119
id := randid()
119120
u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{}))))
120-
err := brain.Learn(ctx, br, "bocchi", id, u, time.Unix(t, 0), strings.Join(toks, " "))
121+
msg := brain.Message{ID: id, Sender: u, Timestamp: t * 1e3, Text: strings.Join(toks, " ")}
122+
err := brain.Learn(ctx, br, "bocchi", &msg)
121123
if err != nil {
122124
b.Errorf("error while learning: %v", err)
123125
}
@@ -162,7 +164,8 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context,
162164
rand.Shuffle(len(toks), func(i, j int) { toks[i], toks[j] = toks[j], toks[i] })
163165
id := randid()
164166
u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{}))))
165-
err := brain.Learn(ctx, br, "bocchi", id, u, time.Unix(t, 0), strings.Join(toks, " "))
167+
msg := brain.Message{ID: id, Sender: u, Timestamp: t * 1e3, Text: strings.Join(toks, " ")}
168+
err := brain.Learn(ctx, br, "bocchi", &msg)
166169
if err != nil {
167170
b.Errorf("error while learning: %v", err)
168171
}
@@ -207,7 +210,8 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context,
207210
rand.Shuffle(len(toks), func(i, j int) { toks[i], toks[j] = toks[j], toks[i] })
208211
id := randid()
209212
u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{}))))
210-
err := brain.Learn(ctx, br, "bocchi", id, u, time.Unix(t, 0), strings.Join(toks, " "))
213+
msg := brain.Message{ID: id, Sender: u, Timestamp: t * 1e3, Text: strings.Join(toks, " ")}
214+
err := brain.Learn(ctx, br, "bocchi", &msg)
211215
if err != nil {
212216
b.Errorf("error while learning: %v", err)
213217
}

brain/braintest/braintest.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ var messages = [...]struct {
9797
func learn(ctx context.Context, t *testing.T, br brain.Learner) {
9898
t.Helper()
9999
for _, m := range messages {
100-
if err := brain.Learn(ctx, br, m.Tag, m.ID, m.User, m.Time, m.Text); err != nil {
100+
msg := brain.Message{ID: m.ID, Sender: m.User, Timestamp: m.Time.UnixMilli(), Text: m.Text}
101+
if err := brain.Learn(ctx, br, m.Tag, &msg); err != nil {
101102
t.Fatalf("couldn't learn message %v: %v", m.ID, err)
102103
}
103104
}
@@ -242,7 +243,8 @@ func testCombinatoric(ctx context.Context, br brain.Brain) func(t *testing.T) {
242243
toks := toks
243244
for len(toks) > 1 {
244245
id := randid()
245-
err := brain.Learn(ctx, br, "bocchi", id, u, time.Unix(0, 0), strings.Join(toks, " "))
246+
msg := brain.Message{ID: id, Sender: u, Text: strings.Join(toks, " ")}
247+
err := brain.Learn(ctx, br, "bocchi", &msg)
246248
if err != nil {
247249
t.Fatalf("couldn't learn init: %v", err)
248250
}

brain/braintest/braintest_test.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"strings"
88
"sync"
99
"testing"
10-
"time"
1110

1211
"github.com/zephyrtronium/robot/brain"
1312
"github.com/zephyrtronium/robot/brain/braintest"
@@ -25,7 +24,7 @@ type membrain struct {
2524

2625
var _ brain.Brain = (*membrain)(nil)
2726

28-
func (m *membrain) Learn(ctx context.Context, tag, id string, user userhash.Hash, t time.Time, tuples []brain.Tuple) error {
27+
func (m *membrain) Learn(ctx context.Context, tag string, msg *brain.Message, tuples []brain.Tuple) error {
2928
m.mu.Lock()
3029
defer m.mu.Unlock()
3130
if m.tups[tag] == nil {
@@ -37,13 +36,13 @@ func (m *membrain) Learn(ctx context.Context, tag, id string, user userhash.Hash
3736
m.tups[tag] = make(map[string][][2]string)
3837
m.tms[tag] = make(map[int64][]string)
3938
}
40-
m.users[user] = append(m.users[user], [2]string{tag, id})
39+
m.users[msg.Sender] = append(m.users[msg.Sender], [2]string{tag, msg.ID})
4140
tms := m.tms[tag]
42-
tms[t.UnixNano()] = append(tms[t.UnixNano()], id)
41+
tms[msg.Timestamp] = append(tms[msg.Timestamp], msg.ID)
4342
r := m.tups[tag]
4443
for _, tup := range tuples {
4544
p := strings.Join(tup.Prefix, "\xff")
46-
r[p] = append(r[p], [2]string{id, tup.Suffix})
45+
r[p] = append(r[p], [2]string{msg.ID, tup.Suffix})
4746
}
4847
return nil
4948
}

brain/kvbrain/forget_test.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,12 @@ func TestForget(t *testing.T) {
203203
}
204204
br := New(db)
205205
for _, msg := range c.msgs {
206-
err := br.Learn(ctx, msg.tag, msg.id, msg.user, msg.time, msg.tups)
206+
m := brain.Message{
207+
ID: msg.id,
208+
Sender: msg.user,
209+
Timestamp: msg.time.UnixMilli(),
210+
}
211+
err := br.Learn(ctx, msg.tag, &m, msg.tups)
207212
if err != nil {
208213
t.Errorf("failed to learn: %v", err)
209214
}

brain/kvbrain/learn.go

+4-5
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,16 @@ import (
55
"context"
66
"errors"
77
"fmt"
8-
"time"
98

109
"github.com/zephyrtronium/robot/brain"
11-
"github.com/zephyrtronium/robot/userhash"
1210
)
1311

1412
// Learn records a set of tuples. Each tuple prefix has length equal to the
1513
// result of Order. The tuples begin with empty strings in the prefix to
1614
// denote the start of the message and end with one empty suffix to denote
1715
// the end; all other tokens are non-empty. Each tuple's prefix has entropy
1816
// reduction transformations applied.
19-
func (br *Brain) Learn(ctx context.Context, tag, id string, user userhash.Hash, t time.Time, tuples []brain.Tuple) error {
17+
func (br *Brain) Learn(ctx context.Context, tag string, msg *brain.Message, tuples []brain.Tuple) error {
2018
if len(tuples) == 0 {
2119
return errors.New("no tuples to learn")
2220
}
@@ -31,7 +29,7 @@ func (br *Brain) Learn(ctx context.Context, tag, id string, user userhash.Hash,
3129
b = hashTag(b[:0], tag)
3230
b = append(appendPrefix(b, t.Prefix), '\xff')
3331
// Write message ID.
34-
b = append(b, id[:]...)
32+
b = append(b, msg.ID...)
3533
keys[i] = bytes.Clone(b)
3634
vals[i] = []byte(t.Suffix)
3735
}
@@ -42,7 +40,8 @@ func (br *Brain) Learn(ctx context.Context, tag, id string, user userhash.Hash,
4240
// overwrite if that happens.
4341
p, _ = br.past.LoadOrStore(tag, new(past))
4442
}
45-
p.record(id, user, t.UnixNano(), keys)
43+
// Scale the timestamp from milliseconds to nanoseconds for historical reasons.
44+
p.record(msg.ID, msg.Sender, msg.Timestamp*1e6, keys)
4645

4746
batch := br.knowledge.NewWriteBatch()
4847
defer batch.Cancel()

brain/kvbrain/learn_test.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,12 @@ func TestLearn(t *testing.T) {
128128
t.Fatal(err)
129129
}
130130
br := New(db)
131-
if err := br.Learn(ctx, c.tag, c.id, c.user, c.time, c.tups); err != nil {
131+
msg := brain.Message{
132+
ID: c.id,
133+
Sender: c.user,
134+
Timestamp: c.time.UnixMilli(),
135+
}
136+
if err := br.Learn(ctx, c.tag, &msg, c.tups); err != nil {
132137
t.Errorf("failed to learn: %v", err)
133138
}
134139
dbcheck(t, db, c.want)

brain/learn.go

+5-7
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@ package brain
33
import (
44
"context"
55
"slices"
6-
"time"
76

87
"github.com/zephyrtronium/robot/tpool"
9-
"github.com/zephyrtronium/robot/userhash"
108
)
119

1210
// Tuple is a single Markov chain tuple.
@@ -26,7 +24,7 @@ type Learner interface {
2624
// of the message. The positions of each in the argument are not guaranteed.
2725
// Each tuple's prefix has entropy reduction transformations applied.
2826
// Tuples in the argument may share storage for prefixes.
29-
Learn(ctx context.Context, tag, id string, user userhash.Hash, t time.Time, tuples []Tuple) error
27+
Learn(ctx context.Context, tag string, msg *Message, tuples []Tuple) error
3028
// Forget forgets everything learned from a single given message.
3129
// If nothing has been learned from the message, it should prevent anything
3230
// from being learned from a message with that ID.
@@ -35,9 +33,9 @@ type Learner interface {
3533

3634
var tuplesPool tpool.Pool[[]Tuple]
3735

38-
// Learn records tokens into a Learner.
39-
func Learn(ctx context.Context, l Learner, tag, id string, user userhash.Hash, t time.Time, text string) error {
40-
toks := Tokens(tokensPool.Get(), text)
36+
// Learn records a message into a Learner.
37+
func Learn(ctx context.Context, l Learner, tag string, msg *Message) error {
38+
toks := Tokens(tokensPool.Get(), msg.Text)
4139
defer func() { tokensPool.Put(toks[:0]) }()
4240
if len(toks) == 0 {
4341
return nil
@@ -46,7 +44,7 @@ func Learn(ctx context.Context, l Learner, tag, id string, user userhash.Hash, t
4644
defer func() { tuplesPool.Put(tt[:0]) }()
4745
tt = slices.Grow(tt, len(toks)+1)
4846
tt = tupleToks(tt, toks)
49-
return l.Learn(ctx, tag, id, user, t, tt)
47+
return l.Learn(ctx, tag, msg, tt)
5048
}
5149

5250
func tupleToks(tt []Tuple, toks []string) []Tuple {

brain/learn_test.go

+2-4
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,17 @@ package brain_test
33
import (
44
"context"
55
"testing"
6-
"time"
76

87
"github.com/google/go-cmp/cmp"
98

109
"github.com/zephyrtronium/robot/brain"
11-
"github.com/zephyrtronium/robot/userhash"
1210
)
1311

1412
type testLearner struct {
1513
learned []brain.Tuple
1614
}
1715

18-
func (t *testLearner) Learn(ctx context.Context, tag, id string, user userhash.Hash, tm time.Time, tuples []brain.Tuple) error {
16+
func (t *testLearner) Learn(ctx context.Context, tag string, msg *brain.Message, tuples []brain.Tuple) error {
1917
t.learned = append(t.learned, tuples...)
2018
return nil
2119
}
@@ -63,7 +61,7 @@ func TestLearn(t *testing.T) {
6361
for _, c := range cases {
6462
t.Run(c.name, func(t *testing.T) {
6563
var l testLearner
66-
err := brain.Learn(context.Background(), &l, "", "", userhash.Hash{}, time.Unix(0, 0), c.msg)
64+
err := brain.Learn(context.Background(), &l, "", &brain.Message{Text: c.msg})
6765
if err != nil {
6866
t.Error(err)
6967
}

brain/sqlbrain/forget_test.go

+18-14
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"strings"
66
"testing"
7-
"time"
87

98
"github.com/zephyrtronium/robot/brain"
109
"github.com/zephyrtronium/robot/brain/sqlbrain"
@@ -107,19 +106,19 @@ func TestForget(t *testing.T) {
107106
{
108107
tag: "kessoku",
109108
id: "2",
110-
time: 3,
109+
time: 3e6,
111110
user: userhash.Hash{1},
112111
},
113112
{
114113
tag: "kessoku",
115114
id: "5",
116-
time: 6,
115+
time: 6e6,
117116
user: userhash.Hash{4},
118117
},
119118
{
120119
tag: "sickhack",
121120
id: "2",
122-
time: 3,
121+
time: 3e6,
123122
user: userhash.Hash{1},
124123
},
125124
}
@@ -206,20 +205,20 @@ func TestForget(t *testing.T) {
206205
{
207206
tag: "kessoku",
208207
id: "2",
209-
time: 3,
208+
time: 3e6,
210209
user: userhash.Hash{1},
211210
deleted: ref("CLEARMSG"),
212211
},
213212
{
214213
tag: "kessoku",
215214
id: "5",
216-
time: 6,
215+
time: 6e6,
217216
user: userhash.Hash{4},
218217
},
219218
{
220219
tag: "sickhack",
221220
id: "2",
222-
time: 3,
221+
time: 3e6,
223222
user: userhash.Hash{1},
224223
},
225224
},
@@ -290,20 +289,20 @@ func TestForget(t *testing.T) {
290289
{
291290
tag: "kessoku",
292291
id: "2",
293-
time: 3,
292+
time: 3e6,
294293
user: userhash.Hash{1},
295294
},
296295
{
297296
tag: "kessoku",
298297
id: "5",
299-
time: 6,
298+
time: 6e6,
300299
user: userhash.Hash{4},
301300
deleted: ref("CLEARMSG"),
302301
},
303302
{
304303
tag: "sickhack",
305304
id: "2",
306-
time: 3,
305+
time: 3e6,
307306
user: userhash.Hash{1},
308307
},
309308
},
@@ -374,19 +373,19 @@ func TestForget(t *testing.T) {
374373
{
375374
tag: "kessoku",
376375
id: "2",
377-
time: 3,
376+
time: 3e6,
378377
user: userhash.Hash{1},
379378
},
380379
{
381380
tag: "kessoku",
382381
id: "5",
383-
time: 6,
382+
time: 6e6,
384383
user: userhash.Hash{4},
385384
},
386385
{
387386
tag: "sickhack",
388387
id: "2",
389-
time: 3,
388+
time: 3e6,
390389
user: userhash.Hash{1},
391390
deleted: ref("CLEARMSG"),
392391
},
@@ -403,7 +402,12 @@ func TestForget(t *testing.T) {
403402
t.Fatalf("couldn't open brain: %v", err)
404403
}
405404
for _, m := range learn {
406-
err := br.Learn(ctx, m.tag, m.id, m.user, time.Unix(0, m.t), m.tups)
405+
msg := brain.Message{
406+
ID: m.id,
407+
Sender: m.user,
408+
Timestamp: m.t,
409+
}
410+
err := br.Learn(ctx, m.tag, &msg, m.tups)
407411
if err != nil {
408412
t.Errorf("failed to learn %v/%v: %v", m.tag, m.id, err)
409413
}

0 commit comments

Comments
 (0)