From 8fa636b7549ef62d9eb0218d8b59960510ed31f2 Mon Sep 17 00:00:00 2001 From: Leo Antunes Date: Wed, 19 Oct 2022 09:37:08 +0200 Subject: [PATCH] add per shared caller function --- singleflight/singleflight.go | 48 ++++++++++++-- singleflight/singleflight_test.go | 101 ++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 5 deletions(-) diff --git a/singleflight/singleflight.go b/singleflight/singleflight.go index 8473fb7..6e62db8 100644 --- a/singleflight/singleflight.go +++ b/singleflight/singleflight.go @@ -45,7 +45,8 @@ func newPanicError(v interface{}) error { // call is an in-flight or completed singleflight.Do call type call struct { - wg sync.WaitGroup + onceWg sync.WaitGroup + othersWg sync.WaitGroup // These fields are written once before the WaitGroup is done // and are only read after the WaitGroup is done. @@ -87,7 +88,7 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, e if c, ok := g.m[key]; ok { c.dups++ g.mu.Unlock() - c.wg.Wait() + c.onceWg.Wait() if e, ok := c.err.(*panicError); ok { panic(e) @@ -97,7 +98,7 @@ func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, e return c.val, c.err, true } c := new(call) - c.wg.Add(1) + c.onceWg.Add(1) g.m[key] = c g.mu.Unlock() @@ -122,7 +123,7 @@ func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result return ch } c := &call{chans: []chan<- Result{ch}} - c.wg.Add(1) + c.onceWg.Add(1) g.m[key] = c g.mu.Unlock() @@ -131,6 +132,43 @@ func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result return ch } +// Inspired by https://github.com/golang/sync/pull/9#issuecomment-572705800 +// `singleFn` is executed only once per key and `othersFn` is executed once per additional caller. +// All callers waiting on DoShared will wait for ALL `othersFn` to finish. +func (g *Group) DoShared(key string, onceFn func() (interface{}, error), othersFn func(interface{}, error)) (v interface{}, err error) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[string]*call) + } + if c, ok := g.m[key]; ok { + c.dups++ + c.othersWg.Add(1) + g.mu.Unlock() + c.onceWg.Wait() + func() { + // TODO: deal with panics the same way as in doCall? + defer c.othersWg.Done() + othersFn(c.val, c.err) + }() + c.othersWg.Wait() + + if e, ok := c.err.(*panicError); ok { + panic(e) + } else if c.err == errGoexit { + runtime.Goexit() + } + return c.val, c.err + } + c := new(call) + c.onceWg.Add(1) + g.m[key] = c + g.mu.Unlock() + + g.doCall(c, key, onceFn) + c.othersWg.Wait() + return c.val, c.err +} + // doCall handles the single call for a key. func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) { normalReturn := false @@ -146,7 +184,7 @@ func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) { g.mu.Lock() defer g.mu.Unlock() - c.wg.Done() + c.onceWg.Done() if g.m[key] == c { delete(g.m, key) } diff --git a/singleflight/singleflight_test.go b/singleflight/singleflight_test.go index 3e51203..1bd73db 100644 --- a/singleflight/singleflight_test.go +++ b/singleflight/singleflight_test.go @@ -318,3 +318,104 @@ func TestPanicDoSharedByDoChan(t *testing.T) { t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do") } } + +func TestDoSharedDupSupress(t *testing.T) { + var g Group + var wg1, wg2 sync.WaitGroup + c := make(chan string, 1) + var calls int32 + fn := func() (interface{}, error) { + if atomic.AddInt32(&calls, 1) == 1 { + // First invocation. + wg1.Done() + } + v := <-c + c <- v // pump; make available for any future calls + + time.Sleep(10 * time.Millisecond) // let more goroutines enter Do + + return v, nil + } + + const n = 10 + wg1.Add(1) + for i := 0; i < n; i++ { + wg1.Add(1) + wg2.Add(1) + go func() { + defer wg2.Done() + wg1.Done() + v, err := g.DoShared("key", fn, func(interface{}, error) {}) + if err != nil { + t.Errorf("Do error: %v", err) + return + } + if s, _ := v.(string); s != "bar" { + t.Errorf("Do = %T %v; want %q", v, v, "bar") + } + }() + } + wg1.Wait() + // At least one goroutine is in fn now and all of them have at + // least reached the line before the Do. + c <- "bar" + wg2.Wait() + if got := atomic.LoadInt32(&calls); got <= 0 || got >= n { + t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) + } +} + +func TestDoSharedOthersCall(t *testing.T) { + var g Group + var wg1, wg2 sync.WaitGroup + c := make(chan string, 1) + var callsOnce int32 + var callsOthers int32 + onceFn := func() (interface{}, error) { + if atomic.AddInt32(&callsOnce, 1) == 1 { + // First invocation. + wg1.Done() + } + v := <-c + c <- v // pump; make available for any future calls + + time.Sleep(10 * time.Millisecond) // let more goroutines enter Do + + return v, nil + } + + othersFn := func(interface{}, error) { + atomic.AddInt32(&callsOthers, 1) + } + + const n = 10 + wg1.Add(1) + for i := 0; i < n; i++ { + wg1.Add(1) + wg2.Add(1) + go func() { + defer wg2.Done() + wg1.Done() + v, err := g.DoShared("key", onceFn, othersFn) + if err != nil { + t.Errorf("Do error: %v", err) + return + } + if s, _ := v.(string); s != "bar" { + t.Errorf("Do = %T %v; want %q", v, v, "bar") + } + }() + } + wg1.Wait() + // At least one goroutine is in fn now and all of them have at + // least reached the line before the Do. + c <- "bar" + wg2.Wait() + gotOnce := atomic.LoadInt32(&callsOnce) + if gotOnce <= 0 || gotOnce >= n { + t.Errorf("number of calls = %d; want over 0 and less than %d", gotOnce, n) + } + if gotOthers := atomic.LoadInt32(&callsOthers); gotOthers != n-gotOnce { + t.Errorf("number of calls = %d; want %d", gotOthers, n-gotOnce) + } +}