diff --git a/error.go b/error.go index 9b348193a..a7bf159c2 100644 --- a/error.go +++ b/error.go @@ -38,6 +38,15 @@ type Error interface { var _ Error = proto.RedisError("") +func isContextError(err error) bool { + switch err { + case context.Canceled, context.DeadlineExceeded: + return true + default: + return false + } +} + func shouldRetry(err error, retryTimeout bool) bool { switch err { case io.EOF, io.ErrUnexpectedEOF: diff --git a/osscluster.go b/osscluster.go index 517fbd450..1e9ee7de4 100644 --- a/osscluster.go +++ b/osscluster.go @@ -1350,7 +1350,9 @@ func (c *ClusterClient) processPipelineNode( _ = node.Client.withProcessPipelineHook(ctx, cmds, func(ctx context.Context, cmds []Cmder) error { cn, err := node.Client.getConn(ctx) if err != nil { - node.MarkAsFailing() + if !isContextError(err) { + node.MarkAsFailing() + } _ = c.mapCmdsByNode(ctx, failedCmds, cmds) setCmdsErr(cmds, err) return err diff --git a/osscluster_test.go b/osscluster_test.go index aeb34c6bd..ccf6daad8 100644 --- a/osscluster_test.go +++ b/osscluster_test.go @@ -539,6 +539,39 @@ var _ = Describe("ClusterClient", func() { AfterEach(func() {}) assertPipeline() + + It("doesn't fail node with context.Canceled error", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + pipe.Set(ctx, "A", "A_value", 0) + _, err := pipe.Exec(ctx) + + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, context.Canceled)).To(BeTrue()) + + clientNodes, _ := client.Nodes(ctx, "A") + + for _, node := range clientNodes { + Expect(node.Failing()).To(BeFalse()) + } + }) + + It("doesn't fail node with context.DeadlineExceeded error", func() { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + + pipe.Set(ctx, "A", "A_value", 0) + _, err := pipe.Exec(ctx) + + Expect(err).To(HaveOccurred()) + Expect(errors.Is(err, context.DeadlineExceeded)).To(BeTrue()) + + clientNodes, _ := client.Nodes(ctx, "A") + + for _, node := range clientNodes { + Expect(node.Failing()).To(BeFalse()) + } + }) }) Describe("with TxPipeline", func() {