Skip to content

Commit dd538e3

Browse files
committed
brain: remove ForgetDuring and ForgetUserSince methods
We keep the information needed to implement these in terms of only message IDs elsewhere. Fixes #52. For #90.
1 parent 0ef3deb commit dd538e3

File tree

10 files changed

+51
-1393
lines changed

10 files changed

+51
-1393
lines changed

brain/braintest/braintest.go

+4-50
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ import (
1818
// If a brain cannot be created without error, new should call t.Fatal.
1919
func Test(ctx context.Context, t *testing.T, new func(context.Context) brain.Brain) {
2020
t.Run("speak", testSpeak(ctx, new(ctx)))
21-
t.Run("forgetMessage", testForgetMessage(ctx, new(ctx)))
22-
t.Run("forgetDuring", testForgetDuring(ctx, new(ctx)))
21+
t.Run("forgetMessage", testForget(ctx, new(ctx)))
2322
t.Run("combinatoric", testCombinatoric(ctx, new(ctx)))
2423
}
2524

@@ -182,11 +181,11 @@ func testSpeak(ctx context.Context, br brain.Brain) func(t *testing.T) {
182181
}
183182
}
184183

185-
// testForgetMessage tests that a brain can forget messages by ID.
186-
func testForgetMessage(ctx context.Context, br brain.Brain) func(t *testing.T) {
184+
// testForget tests that a brain can forget messages by ID.
185+
func testForget(ctx context.Context, br brain.Brain) func(t *testing.T) {
187186
return func(t *testing.T) {
188187
learn(ctx, t, br)
189-
if err := br.ForgetMessage(ctx, "kessoku", messages[0].ID); err != nil {
188+
if err := br.Forget(ctx, "kessoku", messages[0].ID); err != nil {
190189
t.Errorf("failed to forget first message: %v", err)
191190
}
192191
got := speak(ctx, t, br, "kessoku", "", 2048)
@@ -233,51 +232,6 @@ func testForgetMessage(ctx context.Context, br brain.Brain) func(t *testing.T) {
233232
}
234233
}
235234

236-
// testForgetDuring tests that a brain can forget messages in a time range.
237-
func testForgetDuring(ctx context.Context, br brain.Brain) func(t *testing.T) {
238-
return func(t *testing.T) {
239-
learn(ctx, t, br)
240-
if err := br.ForgetDuring(ctx, "kessoku", time.Unix(1, 0).Add(-time.Millisecond), time.Unix(2, 0).Add(time.Millisecond)); err != nil {
241-
t.Errorf("failed to forget: %v", err)
242-
}
243-
got := speak(ctx, t, br, "kessoku", "", 2048)
244-
want := map[string]struct{}{
245-
"1#member bocchi": {},
246-
"1 4#member bocchi": {},
247-
"1 4#member kita": {},
248-
"4#member kita": {},
249-
}
250-
if diff := cmp.Diff(want, got); diff != "" {
251-
t.Errorf("wrong messages after forgetting (+got/-want):\n%s", diff)
252-
}
253-
got = speak(ctx, t, br, "sickhack", "", 2048)
254-
want = map[string]struct{}{
255-
"5#member bocchi": {},
256-
"5 6#member bocchi": {},
257-
"5 7#member bocchi": {},
258-
"5 8#member bocchi": {},
259-
"5 6#member ryou": {},
260-
"6#member ryou": {},
261-
"6 7#member ryou": {},
262-
"6 8#member ryou": {},
263-
"5 7#member nijika": {},
264-
"6 7#member nijika": {},
265-
"7#member nijika": {},
266-
"7 8#member nijika": {},
267-
"5 8#member kita": {},
268-
"6 8#member kita": {},
269-
"7 8#member kita": {},
270-
"8#member kita": {},
271-
"9#manager seika": {},
272-
}
273-
if diff := cmp.Diff(want, got); diff != "" {
274-
t.Errorf("wrong spoken messages for sickhack (+got/-want):\n%s", diff)
275-
}
276-
}
277-
}
278-
279-
// TODO(zeph): testForgetUser
280-
281235
// testCombinatoric tests that chains can generate even with substantial
282236
// overlap in learned material.
283237
func testCombinatoric(ctx context.Context, br brain.Brain) func(t *testing.T) {

brain/braintest/braintest_test.go

+1-43
Original file line numberDiff line numberDiff line change
@@ -66,55 +66,13 @@ func (m *membrain) forgetIDLocked(tag, id string) {
6666
}
6767
}
6868

69-
func (m *membrain) Forget(ctx context.Context, tag string, tuples []brain.Tuple) error {
70-
m.mu.Lock()
71-
defer m.mu.Unlock()
72-
for _, tup := range tuples {
73-
p := strings.Join(tup.Prefix, "\xff")
74-
u := m.tups[tag][p]
75-
k := slices.IndexFunc(u, func(v [2]string) bool { return v[1] == tup.Suffix })
76-
if k < 0 {
77-
continue
78-
}
79-
u[k], u[len(u)-1] = u[len(u)-1], u[k]
80-
m.tups[tag][p] = u[:len(u)-1]
81-
}
82-
return nil
83-
}
84-
85-
func (m *membrain) ForgetMessage(ctx context.Context, tag, id string) error {
69+
func (m *membrain) Forget(ctx context.Context, tag, id string) error {
8670
m.mu.Lock()
8771
defer m.mu.Unlock()
8872
m.forgetIDLocked(tag, id)
8973
return nil
9074
}
9175

92-
func (m *membrain) ForgetDuring(ctx context.Context, tag string, since, before time.Time) error {
93-
m.mu.Lock()
94-
defer m.mu.Unlock()
95-
s, b := since.UnixNano(), before.UnixNano()
96-
for tm, u := range m.tms[tag] {
97-
if tm < s || tm > b {
98-
continue
99-
}
100-
for _, v := range u {
101-
m.forgetIDLocked(tag, v)
102-
}
103-
delete(m.tms[tag], tm) // yea i modify the map during iteration, yea i'm cool
104-
}
105-
return nil
106-
}
107-
108-
func (m *membrain) ForgetUser(ctx context.Context, user *userhash.Hash) error {
109-
m.mu.Lock()
110-
defer m.mu.Unlock()
111-
for _, v := range m.users[*user] {
112-
m.forgetIDLocked(v[0], v[1])
113-
}
114-
delete(m.users, *user)
115-
return nil
116-
}
117-
11876
func (m *membrain) Speak(ctx context.Context, tag string, prompt []string, w *brain.Builder) error {
11977
m.mu.Lock()
12078
defer m.mu.Unlock()

brain/kvbrain/forget.go

+2-89
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"fmt"
77
"slices"
88
"sync"
9-
"time"
109

1110
"github.com/zephyrtronium/robot/userhash"
1211
)
@@ -53,45 +52,9 @@ func (p *past) findID(id string) [][]byte {
5352
return nil
5453
}
5554

56-
// findDuring finds all knowledge keys of messages recorded with timestamps in
57-
// the given time span.
58-
func (p *past) findDuring(since, before int64) [][]byte {
59-
r := make([][]byte, 0, 64)
60-
p.mu.Lock()
61-
defer p.mu.Unlock()
62-
for k, v := range p.time {
63-
if since <= v && v <= before {
64-
keys := p.key[k]
65-
r = slices.Grow(r, len(keys))
66-
for _, v := range keys {
67-
r = append(r, bytes.Clone(v))
68-
}
69-
}
70-
}
71-
return r
72-
}
73-
74-
// findUser finds all knowledge keys of messages recorded from a given user
75-
// since a timestamp.
76-
func (p *past) findUser(user userhash.Hash) [][]byte {
77-
r := make([][]byte, 0, 64)
78-
p.mu.Lock()
79-
defer p.mu.Unlock()
80-
for k, v := range p.user {
81-
if v == user {
82-
keys := p.key[k]
83-
r = slices.Grow(r, len(keys))
84-
for _, v := range keys {
85-
r = append(r, bytes.Clone(v))
86-
}
87-
}
88-
}
89-
return r
90-
}
91-
92-
// ForgetMessage forgets everything learned from a single given message.
55+
// Forget forgets everything learned from a single given message.
9356
// If nothing has been learned from the message, it should be ignored.
94-
func (br *Brain) ForgetMessage(ctx context.Context, tag, id string) error {
57+
func (br *Brain) Forget(ctx context.Context, tag, id string) error {
9558
past, _ := br.past.Load(tag)
9659
if past == nil {
9760
return nil
@@ -111,53 +74,3 @@ func (br *Brain) ForgetMessage(ctx context.Context, tag, id string) error {
11174
}
11275
return nil
11376
}
114-
115-
// ForgetDuring forgets all messages learned in the given time span.
116-
func (br *Brain) ForgetDuring(ctx context.Context, tag string, since, before time.Time) error {
117-
past, _ := br.past.Load(tag)
118-
if past == nil {
119-
return nil
120-
}
121-
keys := past.findDuring(since.UnixNano(), before.UnixNano())
122-
batch := br.knowledge.NewWriteBatch()
123-
defer batch.Cancel()
124-
for _, key := range keys {
125-
err := batch.Delete(key)
126-
if err != nil {
127-
return err
128-
}
129-
}
130-
err := batch.Flush()
131-
if err != nil {
132-
return fmt.Errorf("couldn't commit deleting between times %v and %v: %w", since, before, err)
133-
}
134-
return nil
135-
}
136-
137-
// ForgetUser forgets all messages associated with a userhash.
138-
func (br *Brain) ForgetUser(ctx context.Context, user *userhash.Hash) error {
139-
var rangeErr error
140-
u := *user
141-
br.past.Range(func(tag string, past *past) bool {
142-
keys := past.findUser(u)
143-
if len(keys) == 0 {
144-
return true
145-
}
146-
batch := br.knowledge.NewWriteBatch()
147-
defer batch.Cancel()
148-
for _, key := range keys {
149-
err := batch.Delete(key)
150-
if err != nil {
151-
rangeErr = err
152-
return false
153-
}
154-
}
155-
err := batch.Flush()
156-
if err != nil {
157-
rangeErr = fmt.Errorf("couldn't commit deleting messages by user: %w", err)
158-
return false
159-
}
160-
return false
161-
})
162-
return rangeErr
163-
}

0 commit comments

Comments
 (0)