diff --git a/channellib.go b/channellib.go index 7c3c3709..a92bf72c 100644 --- a/channellib.go +++ b/channellib.go @@ -83,7 +83,20 @@ func channelSelect(L *LState) int { cases[i] = cas } + if L.ctx != nil { + cases = append(cases, reflect.SelectCase{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(L.ctx.Done()), + Send: reflect.ValueOf(nil), + }) + } + pos, recv, rok := reflect.Select(cases) + + if L.ctx != nil && pos == L.GetTop() { + return 0 + } + lv := LNil if recv.Kind() != 0 { lv, _ = recv.Interface().(LValue) @@ -129,7 +142,22 @@ var channelMethods = map[string]LGFunction{ func channelReceive(L *LState) int { rch := checkChannel(L, 1) - v, ok := rch.Recv() + var v reflect.Value + var ok bool + if L.ctx != nil { + cases := []reflect.SelectCase{{ + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(L.ctx.Done()), + Send: reflect.ValueOf(nil), + }, { + Dir: reflect.SelectRecv, + Chan: rch, + Send: reflect.ValueOf(nil), + }} + _, v, ok = reflect.Select(cases) + } else { + v, ok = rch.Recv() + } if ok { L.Push(LTrue) L.Push(v.Interface().(LValue)) diff --git a/channellib_test.go b/channellib_test.go index a7fbffce..883dcea5 100644 --- a/channellib_test.go +++ b/channellib_test.go @@ -1,6 +1,7 @@ package lua import ( + "context" "reflect" "sync" "testing" @@ -259,3 +260,35 @@ func TestChannelSendReceive1(t *testing.T) { go sender(ch) wg.Wait() } + +func TestCancelChannelReceive(t *testing.T) { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer close(done) + L := NewState() + L.SetContext(ctx) + defer L.Close() + L.SetGlobal("ch", LChannel(make(chan LValue))) + errorIfScriptNotFail(t, L, `ch:receive()`, context.Canceled.Error()) + }() + time.Sleep(time.Second) + cancel() + <-done +} + +func TestCancelChannelReceive2(t *testing.T) { + done := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + go func() { + defer close(done) + L := NewState() + L.SetContext(ctx) + defer L.Close() + L.SetGlobal("ch", LChannel(make(chan LValue))) + errorIfScriptNotFail(t, L, `channel.select({"|<-", ch})`, context.Canceled.Error()) + }() + time.Sleep(time.Second) + cancel() + <-done +}